From 5ccd2a2996a3cffac8dafae2d67b277dad6b75d2 Mon Sep 17 00:00:00 2001 From: "Anthony D. Blaom" Date: Mon, 22 May 2023 16:00:36 +1200 Subject: [PATCH 01/15] try getting rid of underscore in _auto_generated_list_of_measures.md --- .gitignore | 2 +- README.md | 2 +- docs/make.jl | 2 +- docs/make_tools.jl | 2 +- docs/src/index.md | 2 +- 5 files changed, 5 insertions(+), 5 deletions(-) diff --git a/.gitignore b/.gitignore index 02bf8b1..f8e630c 100644 --- a/.gitignore +++ b/.gitignore @@ -1,5 +1,5 @@ Manifest.toml -docs/src/_auto_generated_list_of_measures.md +docs/src/auto_generated_list_of_measures.md .ipynb_checkpoints *~ #* diff --git a/README.md b/README.md index 30c298f..7a48b1e 100644 --- a/README.md +++ b/README.md @@ -8,7 +8,7 @@ Measures (metrics) for statistics and machine learning. [![Coverage](https://codecov.io/gh/JuliaAI/StatisticalMeasures.jl/branch/master/graph/badge.svg)](https://codecov.io/github/JuliaAI/StatisticalMeasures.jl?branch=master) [![Docs](https://img.shields.io/badge/docs-dev-blue.svg)](https://juliaai.github.io/StatisticalMeasures.jl/dev/) -[List](https://juliaai.github.io/StatisticalMeasures.jl/dev/_auto_generated_list_of_measures.html#aliases) of provided measures and their aliases thereof. +[List](https://juliaai.github.io/StatisticalMeasures.jl/dev/auto_generated_list_of_measures.html#aliases) of provided measures and their aliases thereof. Powered by [StatisticalMeasuresBase.jl](https://juliaai.github.io/StatisticalMeasuresBase.jl/dev/). 
diff --git a/docs/make.jl b/docs/make.jl index 10214dc..99a3542 100644 --- a/docs/make.jl +++ b/docs/make.jl @@ -20,7 +20,7 @@ makedocs(; pages=[ "Overview" => "index.md", "Examples of usage" => "examples_of_usage.md", - "The Measures" => "_auto_generated_list_of_measures.md", + "The Measures" => "auto_generated_list_of_measures.md", "Confusion Matrices" => "confusion_matrices.md", "Receiver Operator Characteristics" => "roc.md", "Tools" => "tools.md", diff --git a/docs/make_tools.jl b/docs/make_tools.jl index 9de1314..833c17a 100644 --- a/docs/make_tools.jl +++ b/docs/make_tools.jl @@ -39,7 +39,7 @@ function alias_table() end function write_measures_page(path=PATH_TO_DOCS_SRC) - pagename = "_auto_generated_list_of_measures.md" + pagename = "auto_generated_list_of_measures.md" pagepath = joinpath(path, pagename) traits_given_constructor = measures() all_constructors = keys(traits_given_constructor) |> collect diff --git a/docs/src/index.md b/docs/src/index.md index b422433..514e5ac 100644 --- a/docs/src/index.md +++ b/docs/src/index.md @@ -2,7 +2,7 @@
- List of measures  |  Examples From 4ca4f7c075b9b6bb86f001a2de1082531a5fcde3 Mon Sep 17 00:00:00 2001 From: "Anthony D. Blaom" Date: Mon, 22 May 2023 17:40:21 +1200 Subject: [PATCH 02/15] fix link again --- README.md | 2 +- docs/src/index.md | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index 7a48b1e..04dad75 100644 --- a/README.md +++ b/README.md @@ -8,7 +8,7 @@ Measures (metrics) for statistics and machine learning. [![Coverage](https://codecov.io/gh/JuliaAI/StatisticalMeasures.jl/branch/master/graph/badge.svg)](https://codecov.io/github/JuliaAI/StatisticalMeasures.jl?branch=master) [![Docs](https://img.shields.io/badge/docs-dev-blue.svg)](https://juliaai.github.io/StatisticalMeasures.jl/dev/) -[List](https://juliaai.github.io/StatisticalMeasures.jl/dev/auto_generated_list_of_measures.html#aliases) of provided measures and their aliases thereof. +[List](https://juliaai.github.io/StatisticalMeasures.jl/dev/auto_generated_list_of_measures#aliases) of provided measures and their aliases thereof. Powered by [StatisticalMeasuresBase.jl](https://juliaai.github.io/StatisticalMeasuresBase.jl/dev/). diff --git a/docs/src/index.md b/docs/src/index.md index 514e5ac..617bd83 100644 --- a/docs/src/index.md +++ b/docs/src/index.md @@ -2,7 +2,7 @@
- List of measures  |  Examples From ead885ea8d868bcae291c68219ddadbc94871f8a Mon Sep 17 00:00:00 2001 From: "Anthony D. Blaom" Date: Mon, 22 May 2023 17:45:34 +1200 Subject: [PATCH 03/15] add cross ref in docstring for auc to roc_curve --- src/probabilistic.jl | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/probabilistic.jl b/src/probabilistic.jl index bf38359..5be7af0 100644 --- a/src/probabilistic.jl +++ b/src/probabilistic.jl @@ -103,7 +103,8 @@ Core implementation: [`Functions.auc`](@ref). $INVARIANT_LABEL """, - scitype = "" + scitype = "", + footer="See also [`roc_curve`](@ref). ", ) "$AreaUnderCurveDoc" From 4ea640ea35bb80ed7a71ca33902a1a43c0627fd9 Mon Sep 17 00:00:00 2001 From: "Anthony D. Blaom" Date: Mon, 22 May 2023 17:49:26 +1200 Subject: [PATCH 04/15] another docstring tweak --- src/docstrings.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/docstrings.jl b/src/docstrings.jl index 15bf960..518d320 100644 --- a/src/docstrings.jl +++ b/src/docstrings.jl @@ -112,7 +112,7 @@ function docstring(signature; body="", footer="", scitype="") API.can_report_unaggregated(m) && (ret *= "\n\nMeasurements are aggregated. "* "To obtain a separate measurement for each observation, "* - "use the syntax `measurements(m, ŷ, y)`. ") + "use the syntax `measurements($m_str, ŷ, y)`. ") ret *= "Generally, an observation `obs` in `MLUtils.eachobs(y)` is expected to satisfy "* "`ScientificTypes.scitype(obs)<:`$scitype. " # if kind_of_proxy == LearnAPI.LiteralTarget() From bedd5435f03367894d4bab6de155aadbef90f398 Mon Sep 17 00:00:00 2001 From: "Anthony D. Blaom" Date: Mon, 22 May 2023 18:06:36 +1200 Subject: [PATCH 05/15] remove "under construction" icon on readme; yay! --- README.md | 2 -- 1 file changed, 2 deletions(-) diff --git a/README.md b/README.md index 04dad75..c421f1c 100644 --- a/README.md +++ b/README.md @@ -1,7 +1,5 @@ # StatisticalMeasures.jl -🚧 - Measures (metrics) for statistics and machine learning. 
[![Build Status](https://github.com/JuliaAI/StatisticalMeasures.jl/workflows/CI/badge.svg)](https://github.com/JuliaAI/StatisticalMeasures.jl/actions) From c1e1c69625e27d808724b1a65e58a45970f15f55 Mon Sep 17 00:00:00 2001 From: "Anthony D. Blaom" Date: Mon, 22 May 2023 18:10:55 +1200 Subject: [PATCH 06/15] typo --- docs/src/index.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/src/index.md b/docs/src/index.md index 617bd83..c2d773b 100644 --- a/docs/src/index.md +++ b/docs/src/index.md @@ -19,7 +19,7 @@ Measures (metrics) for statistics and machine learning # Overview This package defines common measures (metrics) for classification and regression problems -in statistics and machine learning. To see if your favorite measures is implemented, see +in statistics and machine learning. To see if your favorite measure is implemented, see [this list](@ref aliases). Some multi-target measures are included, but see also [Custom multi-target measures](@ref). From 1ce373f8dcff3de32eefe3f720c70c8343a033cc Mon Sep 17 00:00:00 2001 From: "Anthony D. Blaom" Date: Tue, 23 May 2023 14:04:38 +1200 Subject: [PATCH 07/15] remove debugging printing --- ext/LossFunctionsExt.jl | 2 -- 1 file changed, 2 deletions(-) diff --git a/ext/LossFunctionsExt.jl b/ext/LossFunctionsExt.jl index 3cfa309..fce8367 100644 --- a/ext/LossFunctionsExt.jl +++ b/ext/LossFunctionsExt.jl @@ -31,8 +31,6 @@ loss(measure) = measure # show(io, mime, loss(measure)) # Base.show(io::IO, measure::LossFunctionType) = show(io, loss(measure)) -println("############### LOADED ###############") - # # DISTANCE LOSS TYPE @trait( From 38d1a11314a4a992b87865c7d315b8abae6253c4 Mon Sep 17 00:00:00 2001 From: "Anthony D. 
Blaom" Date: Tue, 23 May 2023 14:06:03 +1200 Subject: [PATCH 08/15] bump 0.1.1 --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index a8d032d..1a90ddf 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "StatisticalMeasures" uuid = "a19d573c-0a75-4610-95b3-7071388c7541" authors = ["Anthony D. Blaom "] -version = "0.1.0" +version = "0.1.1" [deps] CategoricalArrays = "324d7699-5711-5eae-9e2f-1d82baa6b597" From 19ecb30d1dc7bfd74bc93a61297b7f05ecdbbb8f Mon Sep 17 00:00:00 2001 From: "Anthony D. Blaom" Date: Tue, 23 May 2023 14:12:16 +1200 Subject: [PATCH 09/15] don't export precision --- src/StatisticalMeasures.jl | 1 + 1 file changed, 1 insertion(+) diff --git a/src/StatisticalMeasures.jl b/src/StatisticalMeasures.jl index 33c610e..6b8e131 100644 --- a/src/StatisticalMeasures.jl +++ b/src/StatisticalMeasures.jl @@ -49,6 +49,7 @@ end const MEASURES_FOR_EXPORT = let measures = measures() ret = Symbol[] for C in keys(measures) + C === :precision && continue push!(ret, Symbol(C)) for alias in measures[C].aliases push!(ret, Symbol(alias)) From 8ab23e9f168bcfc40f4e07dc97cc25b7d25be5c1 Mon Sep 17 00:00:00 2001 From: "Anthony D. Blaom" Date: Tue, 23 May 2023 15:39:16 +1200 Subject: [PATCH 10/15] improve show for confusion matrices --- src/confusion_matrices.jl | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/src/confusion_matrices.jl b/src/confusion_matrices.jl index 081fba6..642a613 100644 --- a/src/confusion_matrices.jl +++ b/src/confusion_matrices.jl @@ -607,6 +607,10 @@ function Base.show(stream::IO, m::MIME"text/plain", cm::ConfusionMatrix{N} write(stream, take!(iob)) end +function Base.show(stream::IO, cm::ConfusionMatrix{N}) where N + mat = matrix(cm, warn=false) + print(stream, "ConfusionMatrix{$N}($(repr(mat)))") +end # ## STATISTICAL FUNCTIONS ON CONFUSION MATRICES From e2836df41266feb10dd3804ba48ab5a08ee3d329 Mon Sep 17 00:00:00 2001 From: "Anthony D. 
Blaom" Date: Tue, 23 May 2023 16:43:10 +1200 Subject: [PATCH 11/15] add measures(needle) search --- src/registry.jl | 29 ++++++++++++++++++++++++++++- test/registry.jl | 8 +++++++- 2 files changed, 35 insertions(+), 2 deletions(-) diff --git a/src/registry.jl b/src/registry.jl index 38dc0e0..0949314 100644 --- a/src/registry.jl +++ b/src/registry.jl @@ -5,7 +5,7 @@ const ERR_BAD_CONSTRUCTOR = ArgumentError( ) """ - StatisticalMeasusures.measures() + measures() *Experimental* and subject to breaking behavior between patch releases. @@ -16,6 +16,33 @@ traits shared by all measures constructed using the syntax `constructor(args...) """ measures() = TRAITS_GIVEN_CONSTRUCTOR +""" + measures(needle::Union{AbstractString,Regex}) + +Return a dictionary keyed on measure constructors that contain `needle` in their document +strings. + +``` +julia> measures("root") +LittleDict{Any, Any, Vector{Any}, Vector{Any}} with 8 entries: + RootMeanSquaredError => (aliases = ("rms", "rmse", "root_mean_squared_error"), c… + MultitargetRootMeanSquaredEr… => (aliases = ("multitarget_rms", "multitarget_rmse", "mult… + RootMeanSquaredLogError => (aliases = ("rmsl", "rmsle", "root_mean_squared_log_erro… + MultitargetRootMeanSquaredLo… => (aliases = ("multitarget_rmsl", "multitarget_rmsle", "mu… + RootMeanSquaredLogProportion… => (aliases = ("rmslp1",), consumes_multiple_observations =… + MultitargetRootMeanSquaredLo… => (aliases = ("multitarget_rmslp1",), consumes_multiple_ob… + RootMeanSquaredProportionalE… => (aliases = ("rmsp",), consumes_multiple_observations = t… + MultitargetRootMeanSquaredPr… => (aliases = ("multitarget_rmsp",), consumes_multiple_obse… +``` + +""" +function measures(needle::Union{AbstractString,Regex}) + filter(measures()) do (constructor, _) + doc = Base.Docs.doc(constructor) |> string + occursin(needle, doc) + end +end + """ StatisticalMeasures.register(constructor, aliases=String[]) diff --git a/test/registry.jl b/test/registry.jl index a100437..13fb575 100644 --- 
a/test/registry.jl +++ b/test/registry.jl @@ -11,5 +11,11 @@ measure = LPLossOnScalars() @test API.$trait(measure) == getproperty(metadata, $trait_ex) end |> eval end - @test API.measures()[LPLossOnVectors].aliases == ("l2", ) + @test measures()[LPLossOnVectors].aliases == ("l2", ) +end + +@testset "search for needle in docstring" begin + ms = measures("Matthew") + @test [keys(ms)...] == [MatthewsCorrelation,] + @test measures()[MatthewsCorrelation] == ms[MatthewsCorrelation] end From 118ad95fbafdac0b76a5516b2ae3948e9bb1c07c Mon Sep 17 00:00:00 2001 From: "Anthony D. Blaom" Date: Wed, 24 May 2023 09:05:21 +1200 Subject: [PATCH 12/15] add measure search by trait value --- src/registry.jl | 50 ++++++++++++++++++++++++++++++++++++++++++------ test/registry.jl | 19 ++++++++++++++++++ 2 files changed, 63 insertions(+), 6 deletions(-) diff --git a/src/registry.jl b/src/registry.jl index 0949314..1b03b19 100644 --- a/src/registry.jl +++ b/src/registry.jl @@ -4,21 +4,58 @@ const ERR_BAD_CONSTRUCTOR = ArgumentError( "Constructor must have a zero argument method. " ) +const ERR_BAD_KWARG(trait) = ArgumentError( + "`$trait` is not a valid measure trait name. You must choose one of these: "* + "$(API.OVERLOADABLE_TRAITS_LIST). " +) + """ - measures() + measures(; filter_options...) *Experimental* and subject to breaking behavior between patch releases. Return a dictionary, `dict`, keyed on measure constructors provided by -StatisticalMeasures.jl. The value of `dict[constructor]` provides information about -traits shared by all measures constructed using the syntax `constructor(args...)`. +StatisticalMeasures.jl. The value of `dict[constructor]` provides information about traits +(measure "metadata") shared by all measures constructed using the syntax +`constructor(args...)`. 
+ +# Filter options + +One can filter on the basis of measure trait values, as shown in this example: + +``` +using StatisticalMeasures +using ScientificTypes + +julia> measures( + observation_scitype = Union{Missing,Multiclass}, + supports_class_weights = true, +) +``` + +For more general searches, use a `filter(measures()) do (_, metadata) ... end` +construction. """ -measures() = TRAITS_GIVEN_CONSTRUCTOR +measures(; kwargs...) = filter(TRAITS_GIVEN_CONSTRUCTOR) do (_, metadata) + trait_value_pairs = collect(kwargs) + traits = first.(trait_value_pairs) + for trait in traits + trait in API.OVERLOADABLE_TRAITS || throw(ERR_BAD_KWARG(trait)) + end + all(trait_value_pairs) do pair + trait = first(pair) + value = last(pair) + getproperty(metadata, trait) == value + end +end + """ measures(needle::Union{AbstractString,Regex}) +*Experimental* and subject to breaking behavior between patch releases. + Return a dictionary keyed on measure constructors that contain `needle` in their document strings. @@ -36,13 +73,14 @@ LittleDict{Any, Any, Vector{Any}, Vector{Any}} with 8 entries: ``` """ -function measures(needle::Union{AbstractString,Regex}) - filter(measures()) do (constructor, _) +function measures(needle::Union{AbstractString,Regex}; kwargs...) + filter(measures(; kwargs...)) do (constructor, _) doc = Base.Docs.doc(constructor) |> string occursin(needle, doc) end end + """ StatisticalMeasures.register(constructor, aliases=String[]) diff --git a/test/registry.jl b/test/registry.jl index 13fb575..e70144d 100644 --- a/test/registry.jl +++ b/test/registry.jl @@ -19,3 +19,22 @@ end @test [keys(ms)...] 
== [MatthewsCorrelation,] @test measures()[MatthewsCorrelation] == ms[MatthewsCorrelation] end + +@testset "search using trait values" begin + ms = measures( + observation_scitype = Union{Missing,Multiclass}, + supports_class_weights = true, + ) + # test filter only catches true matches: + @test all(ms) do (_, metadata) + metadata.observation_scitype == Union{Missing,Multiclass} && + metadata.supports_class_weights + end + # find on basis of a mutually exclusive condition: + ms! = measures( + observation_scitype = Union{Missing,Multiclass}, + supports_class_weights = false, + ) + # check no matches in common: + @test isempty(intersect(keys(ms), keys(ms!))) +end From bd1b22d8981b5cca3249218875ebce8d192834d9 Mon Sep 17 00:00:00 2001 From: "Anthony D. Blaom" Date: Wed, 24 May 2023 13:55:28 +1200 Subject: [PATCH 13/15] add search by measure signature --- Project.toml | 4 +++ docs/Project.toml | 1 + docs/make.jl | 1 + docs/src/examples_of_usage.md | 22 +++++++++++- docs/src/tools.md | 2 +- ext/ScientificTypesExt.jl | 41 ++++++++++++++++++++++ src/StatisticalMeasures.jl | 3 +- src/registry.jl | 64 +++++++++++++++++++++-------------- src/tools.jl | 1 + test/ScientificTypesExt.jl | 29 ++++++++++++++++ test/runtests.jl | 4 +++ 11 files changed, 144 insertions(+), 28 deletions(-) create mode 100644 ext/ScientificTypesExt.jl create mode 100644 test/ScientificTypesExt.jl diff --git a/Project.toml b/Project.toml index 1a90ddf..688db7f 100644 --- a/Project.toml +++ b/Project.toml @@ -13,6 +13,7 @@ LossFunctions = "30fc2ffe-d236-52d8-8643-a9d8f7c094a7" MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09" OrderedCollections = "bac558e1-5e72-5ebc-8fee-abe8a469f55d" PrecompileTools = "aea7be01-6a6a-4083-8856-8a6e6704d82a" +ScientificTypes = "321657f4-b219-11e9-178b-2701a2544e81" ScientificTypesBase = "30f210dd-8aff-4c5f-94ba-8e64358c1161" StatisticalMeasuresBase = "c062fc1d-0d66-479b-b6ac-8b44719de4cc" Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" @@ -20,9 +21,11 @@ StatsBase = 
"2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91" [weakdeps] LossFunctions = "30fc2ffe-d236-52d8-8643-a9d8f7c094a7" +ScientificTypes = "321657f4-b219-11e9-178b-2701a2544e81" [extensions] LossFunctionsExt = "LossFunctions" +ScientificTypesExt = "ScientificTypes" [compat] CategoricalArrays = "0.10" @@ -33,6 +36,7 @@ LossFunctions = "0.10" MacroTools = "0.5" OrderedCollections = "1" PrecompileTools = "1.1" +ScientificTypes = "3" ScientificTypesBase = "3" StatisticalMeasuresBase = "0.1" StatsBase = "0.33, 0.34" diff --git a/docs/Project.toml b/docs/Project.toml index fb13661..68d36ca 100644 --- a/docs/Project.toml +++ b/docs/Project.toml @@ -4,6 +4,7 @@ CategoricalDistributions = "af321ab8-2d2e-40a6-b165-3d674595d28e" Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4" LossFunctions = "30fc2ffe-d236-52d8-8643-a9d8f7c094a7" +ScientificTypes = "321657f4-b219-11e9-178b-2701a2544e81" ScientificTypesBase = "30f210dd-8aff-4c5f-94ba-8e64358c1161" StatisticalMeasures = "a19d573c-0a75-4610-95b3-7071388c7541" StatisticalMeasuresBase = "c062fc1d-0d66-479b-b6ac-8b44719de4cc" diff --git a/docs/make.jl b/docs/make.jl index 99a3542..3aa2dd9 100644 --- a/docs/make.jl +++ b/docs/make.jl @@ -3,6 +3,7 @@ using StatisticalMeasures using StatisticalMeasures.StatisticalMeasuresBase using StatisticalMeasures.LearnAPI using ScientificTypesBase +using ScientificTypes const REPO="github.com/JuliaAI/StatisticalMeasures.jl" diff --git a/docs/src/examples_of_usage.md b/docs/src/examples_of_usage.md index 4ab38af..d481716 100644 --- a/docs/src/examples_of_usage.md +++ b/docs/src/examples_of_usage.md @@ -35,7 +35,7 @@ bacc == BalancedAccuracy() == BalancedAccuracy(adjusted=false) - [Probabilistic regression](@ref) - [Custom multi-target measures](@ref) - [Using losses from LossFunctions.jl](@ref) - +- [Measure search (experimental feature)](@ref) ## Binary classification @@ -333,3 +333,23 @@ Wrap again, as shown in the preceding section, to get a 
multi-target version. For distance-based loss functions, wrapping in `Measure` is not strictly necessary, but does no harm. + + +## Measure search (experimental feature) + +```@example 29 +using StatisticalMeasures +using ScientificTypes + +y = rand(3) +yhat = rand(3) +options = measures(yhat, y, supports_weights=true) +``` + +```@example 29 +options[LPLoss] +``` + +```@example 29 +measures("Matthew") +``` diff --git a/docs/src/tools.md b/docs/src/tools.md index 6762aea..40cceec 100644 --- a/docs/src/tools.md +++ b/docs/src/tools.md @@ -3,7 +3,7 @@ | method | description | |:----------------------------------------------|:-------------------------------------------------------------------------------------| | [`measurements`](@ref)`(measure, ...)` | for obtaining per-observation measurements, instead of aggregated ones | -| [`measures()`](@ref) | dictionary of traits keyed on measure constructors | +| [`measures()`](@ref) | dictionary of traits keyed on measure constructors, with filter options | | [`unfussy(measure)`](@ref) | new measure without argument checks¹ | | [`multimeasure`](@ref)`(measure; options...)` | wrapper to broadcast measures over multiple observations | | [`robust_measure(measure)`](@ref) | wrapper to silently treat unsupported weights as uniform | diff --git a/ext/ScientificTypesExt.jl b/ext/ScientificTypesExt.jl new file mode 100644 index 0000000..29f677e --- /dev/null +++ b/ext/ScientificTypesExt.jl @@ -0,0 +1,41 @@ +module ScientificTypesExt +using ScientificTypes +import ScientificTypes.Tables +using StatisticalMeasures +import StatisticalMeasuresBase.MLUtils +import Distributions +import LearnAPI + +# # HELPERS + +guess_observation_scitype(y) = guess_observation_scitype(y, Val(Tables.istable(y))) +guess_observation_scitype(y, ::Val{false}) = MLUtils.getobs(y, 1) |> scitype +guess_observation_scitype(table, ::Val{true}) = + MLUtils.getobs(table, 1) |> collect |> scitype + + +# # MEASURE SEARCH BASED ON ARGUMENTS + 
+StatisticalMeasures.measures(y; kwargs...) = filter(measures(; kwargs...)) do (_, metadata) + guess_observation_scitype(y) <: metadata.observation_scitype +end + +function StatisticalMeasures.measures(yhat, y; trait_filters...) + y_scitype = guess_observation_scitype(y) + yhat_scitype = guess_observation_scitype(yhat) + filter(measures(; trait_filters...)) do (_, metadata) + requirement1 = y_scitype <: metadata.observation_scitype + proxy = metadata.kind_of_proxy + requirement2 = if proxy == LearnAPI.LiteralTarget() + yhat_scitype <: metadata.observation_scitype + elseif proxy == LearnAPI.Distribution() + yhat_scitype <: Density{<:y_scitype} + else + false + end + requirement3 = !Tables.istable(y) || metadata.can_consume_tables + requirement1 && requirement2 && requirement3 + end +end + +end # module diff --git a/src/StatisticalMeasures.jl b/src/StatisticalMeasures.jl index 6b8e131..c530165 100644 --- a/src/StatisticalMeasures.jl +++ b/src/StatisticalMeasures.jl @@ -43,7 +43,8 @@ include("precompile.jl") # remove after julia LTS supports pkg extensions: if !isdefined(Base, :get_extension) - include("../ext/LossFunctionsExt.jl") + include("../ext/LossFunctionsExt.jl") + include("../ext/ScientificTypesExt.jl") end const MEASURES_FOR_EXPORT = let measures = measures() diff --git a/src/registry.jl b/src/registry.jl index 1b03b19..70b7de1 100644 --- a/src/registry.jl +++ b/src/registry.jl @@ -10,7 +10,7 @@ const ERR_BAD_KWARG(trait) = ArgumentError( ) """ - measures(; filter_options...) + measures(; trait_options...) *Experimental* and subject to breaking behavior between patch releases. @@ -19,13 +19,13 @@ StatisticalMeasures.jl. The value of `dict[constructor]` provides information ab (measure "metadata") shared by all measures constructed using the syntax `constructor(args...)`. 
-# Filter options +# Trait options One can filter on the basis of measure trait values, as shown in this example: ``` using StatisticalMeasures -using ScientificTypes +import ScientificTypesBase.Multiclass julia> measures( observation_scitype = Union{Missing,Multiclass}, @@ -33,12 +33,36 @@ julia> measures( ) ``` -For more general searches, use a `filter(measures()) do (_, metadata) ... end` -construction. +--- + + measures(y; trait_filters...) + measures(yhat, y; trait_filters...) + +*Experimental* and subject to breaking behavior between patch releases. + +Assuming ScientificTypes.jl has been imported, find measures that can be applied to data +with the specified data arguments `(y,)` or `(yhat, y)`. It is assumed that the arguments +contain multiple observations (have types implementing `MLUtils.getobs`). + +Returns a dictionary keyed on the constructors of such measures. Additional +`trait_filters` are the same as for the zero argument `measures` method. + +```julia +using StatisticalMeasures +using ScientificTypes + +julia> measures(rand(3), rand(3), supports_weights=false) +LittleDict{Any, Any, Vector{Any}, Vector{Any}} with 1 entry: + RSquared => (aliases = ("rsq", "rsquared"), consumes_multiple_observations = true, can_re… +``` + +*Warning.* Matching is based only on the *first* observation of the arguments provided, +and must be interpreted carefully if, for example, `y` or `yhat` are vectors with `Union` +or other abstract element types. """ -measures(; kwargs...) = filter(TRAITS_GIVEN_CONSTRUCTOR) do (_, metadata) - trait_value_pairs = collect(kwargs) +measures(; trait_options...) = filter(TRAITS_GIVEN_CONSTRUCTOR) do (_, metadata) + trait_value_pairs = collect(trait_options) traits = first.(trait_value_pairs) for trait in traits trait in API.OVERLOADABLE_TRAITS || throw(ERR_BAD_KWARG(trait)) @@ -50,37 +74,27 @@ measures(; kwargs...)
= filter(TRAITS_GIVEN_CONSTRUCTOR) do (_, metadata) end end - """ - measures(needle::Union{AbstractString,Regex}) + measures(needle::Union{AbstractString,Regex}; trait_options...) *Experimental* and subject to breaking behavior between patch releases. -Return a dictionary keyed on measure constructors that contain `needle` in their document -strings. +Find measures that contain `needle` in their document string. Returns a dictionary keyed +on the constructors of such measures. ``` -julia> measures("root") -LittleDict{Any, Any, Vector{Any}, Vector{Any}} with 8 entries: - RootMeanSquaredError => (aliases = ("rms", "rmse", "root_mean_squared_error"), c… - MultitargetRootMeanSquaredEr… => (aliases = ("multitarget_rms", "multitarget_rmse", "mult… - RootMeanSquaredLogError => (aliases = ("rmsl", "rmsle", "root_mean_squared_log_erro… - MultitargetRootMeanSquaredLo… => (aliases = ("multitarget_rmsl", "multitarget_rmsle", "mu… - RootMeanSquaredLogProportion… => (aliases = ("rmslp1",), consumes_multiple_observations =… - MultitargetRootMeanSquaredLo… => (aliases = ("multitarget_rmslp1",), consumes_multiple_ob… - RootMeanSquaredProportionalE… => (aliases = ("rmsp",), consumes_multiple_observations = t… - MultitargetRootMeanSquaredPr… => (aliases = ("multitarget_rmsp",), consumes_multiple_obse… +julia> measures("Matthew") +LittleDict{Any, Any, Vector{Any}, Vector{Any}} with 1 entry: + MatthewsCorrelation => (aliases = ("matthews_correlation", "mcc"), consumes_multiple_obse… ``` - """ -function measures(needle::Union{AbstractString,Regex}; kwargs...) - filter(measures(; kwargs...)) do (constructor, _) +function measures(needle::Union{AbstractString,Regex}; trait_options...) 
+ filter(measures(; trait_options...)) do (constructor, _) doc = Base.Docs.doc(constructor) |> string occursin(needle, doc) end end - """ StatisticalMeasures.register(constructor, aliases=String[]) diff --git a/src/tools.jl b/src/tools.jl index e4a5d71..2bfdda8 100644 --- a/src/tools.jl +++ b/src/tools.jl @@ -54,3 +54,4 @@ function API.check_pools( end return nothing end + diff --git a/test/ScientificTypesExt.jl b/test/ScientificTypesExt.jl new file mode 100644 index 0000000..8067ebd --- /dev/null +++ b/test/ScientificTypesExt.jl @@ -0,0 +1,29 @@ +using ScientificTypes +using ScientificTypes.Tables + +n = 10 +p = 3 +y = rand(p, n) +yhat = rand(p, n) +t = y' |> Tables.table |> Tables.columntable +that = yhat' |> Tables.table |> Tables.columntable + +ms = measures(t) +@test all(ms) do (_, metadata) + AbstractVector{Continuous} <: metadata.observation_scitype +end + +ms = measures(that, t) +@test all(ms) do (_, metadata) + AbstractVector{Continuous} <: metadata.observation_scitype +end + +y = categorical(rand("ab", n)) +yhat = UnivariateFinite(levels(y), rand(n), augment=true, pool=y) +ms2 = measures((yhat, y)) +@test all(ms2) do (_, metadata) + Multiclass{2} <: metadata.observation_scitype && + metadata.kind_of_proxy == LearnAPI.Distribution() +end + +@test isempty(intersect(keys(ms), keys(ms2))) diff --git a/test/runtests.jl b/test/runtests.jl index e1ffc39..2d59f9c 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -47,3 +47,7 @@ end @testset "LossFunctionsExt.jl" begin include("LossFunctionsExt.jl") end + +@testset "ScientificTypesExt.jl" begin + include("ScientificTypesExt.jl") +end From e1538a94812d46f7b3e8811712495682bd43182b Mon Sep 17 00:00:00 2001 From: "Anthony D. 
Blaom" Date: Thu, 25 May 2023 09:16:39 +1200 Subject: [PATCH 14/15] fix failure to stop precision export --- src/StatisticalMeasures.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/StatisticalMeasures.jl b/src/StatisticalMeasures.jl index c530165..06c2e8a 100644 --- a/src/StatisticalMeasures.jl +++ b/src/StatisticalMeasures.jl @@ -50,9 +50,9 @@ end const MEASURES_FOR_EXPORT = let measures = measures() ret = Symbol[] for C in keys(measures) - C === :precision && continue push!(ret, Symbol(C)) for alias in measures[C].aliases + alias == "precision" && continue push!(ret, Symbol(alias)) end end From 87a9d66c37671fcf4be30ab998acd233fc9d689b Mon Sep 17 00:00:00 2001 From: "Anthony D. Blaom" Date: Thu, 25 May 2023 11:53:06 +1200 Subject: [PATCH 15/15] fix an issue with docstring generation in Functions module --- src/functions.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/functions.jl b/src/functions.jl index 5e99630..fc658b4 100644 --- a/src/functions.jl +++ b/src/functions.jl @@ -163,7 +163,7 @@ end # # FUNCTIONS ON MATRICES INTERPRETED AS CONFUSION MATRICES -clean(s) = join(split(s, "_"), " ") +clean(s) = join(split(last(split(s, ".")), "_"), " ") function docstring(measure; sig="(m)", name=clean(measure), binary=false, the=false) footer = binary ? "The first index corresponds to the \"negative\" class, "*