Skip to content

Commit

Permalink
Merge pull request #3 from JuliaAI/dev
Browse files Browse the repository at this point in the history
For a 0.1.1 release
  • Loading branch information
ablaom authored May 25, 2023
2 parents 780c7ec + 87a9d66 commit be112a3
Show file tree
Hide file tree
Showing 21 changed files with 230 additions and 22 deletions.
2 changes: 1 addition & 1 deletion .gitignore
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
Manifest.toml
docs/src/_auto_generated_list_of_measures.md
docs/src/auto_generated_list_of_measures.md
.ipynb_checkpoints
*~
#*
Expand Down
6 changes: 5 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "StatisticalMeasures"
uuid = "a19d573c-0a75-4610-95b3-7071388c7541"
authors = ["Anthony D. Blaom <[email protected]>"]
version = "0.1.0"
version = "0.1.1"

[deps]
CategoricalArrays = "324d7699-5711-5eae-9e2f-1d82baa6b597"
Expand All @@ -13,16 +13,19 @@ LossFunctions = "30fc2ffe-d236-52d8-8643-a9d8f7c094a7"
MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09"
OrderedCollections = "bac558e1-5e72-5ebc-8fee-abe8a469f55d"
PrecompileTools = "aea7be01-6a6a-4083-8856-8a6e6704d82a"
ScientificTypes = "321657f4-b219-11e9-178b-2701a2544e81"
ScientificTypesBase = "30f210dd-8aff-4c5f-94ba-8e64358c1161"
StatisticalMeasuresBase = "c062fc1d-0d66-479b-b6ac-8b44719de4cc"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"

[weakdeps]
LossFunctions = "30fc2ffe-d236-52d8-8643-a9d8f7c094a7"
ScientificTypes = "321657f4-b219-11e9-178b-2701a2544e81"

[extensions]
LossFunctionsExt = "LossFunctions"
ScientificTypesExt = "ScientificTypes"

[compat]
CategoricalArrays = "0.10"
Expand All @@ -33,6 +36,7 @@ LossFunctions = "0.10"
MacroTools = "0.5"
OrderedCollections = "1"
PrecompileTools = "1.1"
ScientificTypes = "3"
ScientificTypesBase = "3"
StatisticalMeasuresBase = "0.1"
StatsBase = "0.33, 0.34"
Expand Down
4 changes: 1 addition & 3 deletions README.md
Original file line number Diff line number Diff line change
@@ -1,14 +1,12 @@
# StatisticalMeasures.jl

&#128679;

Measures (metrics) for statistics and machine learning.

[![Build Status](https://github.com/JuliaAI/StatisticalMeasures.jl/workflows/CI/badge.svg)](https://github.com/JuliaAI/StatisticalMeasures.jl/actions)
[![Coverage](https://codecov.io/gh/JuliaAI/StatisticalMeasures.jl/branch/master/graph/badge.svg)](https://codecov.io/github/JuliaAI/StatisticalMeasures.jl?branch=master)
[![Docs](https://img.shields.io/badge/docs-dev-blue.svg)](https://juliaai.github.io/StatisticalMeasures.jl/dev/)

[List](https://juliaai.github.io/StatisticalMeasures.jl/dev/_auto_generated_list_of_measures.html#aliases) of provided measures and their aliases thereof.
[List](https://juliaai.github.io/StatisticalMeasures.jl/dev/auto_generated_list_of_measures#aliases) of provided measures and their aliases.

Powered by
[StatisticalMeasuresBase.jl](https://juliaai.github.io/StatisticalMeasuresBase.jl/dev/).
Expand Down
1 change: 1 addition & 0 deletions docs/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ CategoricalDistributions = "af321ab8-2d2e-40a6-b165-3d674595d28e"
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
LossFunctions = "30fc2ffe-d236-52d8-8643-a9d8f7c094a7"
ScientificTypes = "321657f4-b219-11e9-178b-2701a2544e81"
ScientificTypesBase = "30f210dd-8aff-4c5f-94ba-8e64358c1161"
StatisticalMeasures = "a19d573c-0a75-4610-95b3-7071388c7541"
StatisticalMeasuresBase = "c062fc1d-0d66-479b-b6ac-8b44719de4cc"
Expand Down
3 changes: 2 additions & 1 deletion docs/make.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ using StatisticalMeasures
using StatisticalMeasures.StatisticalMeasuresBase
using StatisticalMeasures.LearnAPI
using ScientificTypesBase
using ScientificTypes

const REPO="github.com/JuliaAI/StatisticalMeasures.jl"

Expand All @@ -20,7 +21,7 @@ makedocs(;
pages=[
"Overview" => "index.md",
"Examples of usage" => "examples_of_usage.md",
"The Measures" => "_auto_generated_list_of_measures.md",
"The Measures" => "auto_generated_list_of_measures.md",
"Confusion Matrices" => "confusion_matrices.md",
"Receiver Operator Characteristics" => "roc.md",
"Tools" => "tools.md",
Expand Down
2 changes: 1 addition & 1 deletion docs/make_tools.jl
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ function alias_table()
end

function write_measures_page(path=PATH_TO_DOCS_SRC)
pagename = "_auto_generated_list_of_measures.md"
pagename = "auto_generated_list_of_measures.md"
pagepath = joinpath(path, pagename)
traits_given_constructor = measures()
all_constructors = keys(traits_given_constructor) |> collect
Expand Down
22 changes: 21 additions & 1 deletion docs/src/examples_of_usage.md
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ bacc == BalancedAccuracy() == BalancedAccuracy(adjusted=false)
- [Probabilistic regression](@ref)
- [Custom multi-target measures](@ref)
- [Using losses from LossFunctions.jl](@ref)

- [Measure search (experimental feature)](@ref)

## Binary classification

Expand Down Expand Up @@ -333,3 +333,23 @@ Wrap again, as shown in the preceding section, to get a multi-target version.

For distance-based loss functions, wrapping in `Measure` is not strictly
necessary, but does no harm.


## Measure search (experimental feature)

```@example 29
using StatisticalMeasures
using ScientificTypes
y = rand(3)
yhat = rand(3)
options = measures(yhat, y, supports_weights=true)
```

```@example 29
options[LPLoss]
```

```@example 29
measures("Matthew")
```
4 changes: 2 additions & 2 deletions docs/src/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
<script async defer src="https://buttons.github.io/buttons.js"></script>
<div style="font-size:1.4em;font-weight:bold;">
<a href="https://juliaai.github.io/StatisticalMeasures.jl/dev/_auto_generated_list_of_measures.html#aliases"
<a href="https://juliaai.github.io/StatisticalMeasures.jl/dev/auto_generated_list_of_measures#aliases"
style="color: #9558B2;">List of measures</a> &nbsp;|&nbsp;
<a href="examples_of_usage"
style="color: #389826;">Examples</a>
Expand All @@ -19,7 +19,7 @@ Measures (metrics) for statistics and machine learning</span>
# Overview

This package defines common measures (metrics) for classification and regression problems
in statistics and machine learning. To see if your favorite measures is implemented, see
in statistics and machine learning. To see if your favorite measure is implemented, see
[this list](@ref aliases). Some multi-target measures are included, but see also [Custom
multi-target measures](@ref).

Expand Down
2 changes: 1 addition & 1 deletion docs/src/tools.md
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
| method | description |
|:----------------------------------------------|:-------------------------------------------------------------------------------------|
| [`measurements`](@ref)`(measure, ...)` | for obtaining per-observation measurements, instead of aggregated ones |
| [`measures()`](@ref) | dictionary of traits keyed on measure constructors |
| [`measures()`](@ref) | dictionary of traits keyed on measure constructors, with filter options |
| [`unfussy(measure)`](@ref) | new measure without argument checks¹ |
| [`multimeasure`](@ref)`(measure; options...)` | wrapper to broadcast measures over multiple observations |
| [`robust_measure(measure)`](@ref) | wrapper to silently treat unsupported weights as uniform |
Expand Down
2 changes: 0 additions & 2 deletions ext/LossFunctionsExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,6 @@ loss(measure) = measure
# show(io, mime, loss(measure))
# Base.show(io::IO, measure::LossFunctionType) = show(io, loss(measure))

println("############### LOADED ###############")

# # DISTANCE LOSS TYPE

@trait(
Expand Down
41 changes: 41 additions & 0 deletions ext/ScientificTypesExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
module ScientificTypesExt
using ScientificTypes
import ScientificTypes.Tables
using StatisticalMeasures
import StatisticalMeasuresBase.MLUtils
import Distributions
import LearnAPI

# # HELPERS

"""
    guess_observation_scitype(y)

Return the scientific type of the *first* observation of `y`, where that observation is
extracted using `MLUtils.getobs(y, 1)`. If `y` is a table, according to `Tables.istable`,
the first observation is passed through `collect` before its scitype is computed.

Only the first observation is inspected, so the result is a guess.
"""
guess_observation_scitype(y) = guess_observation_scitype(y, Val(Tables.istable(y)))
guess_observation_scitype(y, ::Val{false}) = scitype(MLUtils.getobs(y, 1))
guess_observation_scitype(table, ::Val{true}) =
    scitype(collect(MLUtils.getobs(table, 1)))


# # MEASURE SEARCH BASED ON ARGUMENTS

# Single data argument: keep those measures whose `observation_scitype` admits the
# (guessed) scitype of an observation of `y`.
function StatisticalMeasures.measures(y; kwargs...)
    # Depends only on `y`, so compute it once instead of once per registry entry:
    y_scitype = guess_observation_scitype(y)
    return filter(measures(; kwargs...)) do (_, metadata)
        y_scitype <: metadata.observation_scitype
    end
end

# Two data arguments: additionally require that `yhat` is compatible with the measure's
# `kind_of_proxy`, and that measures can consume tables whenever `y` is a table.
function StatisticalMeasures.measures(yhat, y; trait_filters...)
    # Everything here is independent of the measure being examined:
    y_scitype = guess_observation_scitype(y)
    yhat_scitype = guess_observation_scitype(yhat)
    y_is_table = Tables.istable(y)

    return filter(measures(; trait_filters...)) do (_, metadata)
        # ground truth observations must have an admissible scitype:
        requirement1 = y_scitype <: metadata.observation_scitype

        # predictions must match the kind of proxy the measure expects; any proxy
        # other than the two handled below is never matched:
        proxy = metadata.kind_of_proxy
        requirement2 = if proxy == LearnAPI.LiteralTarget()
            yhat_scitype <: metadata.observation_scitype
        elseif proxy == LearnAPI.Distribution()
            yhat_scitype <: Density{<:y_scitype}
        else
            false
        end

        # tabular `y` is only acceptable to measures declaring they consume tables:
        requirement3 = !y_is_table || metadata.can_consume_tables

        requirement1 && requirement2 && requirement3
    end
end

end # module
4 changes: 3 additions & 1 deletion src/StatisticalMeasures.jl
Original file line number Diff line number Diff line change
Expand Up @@ -43,14 +43,16 @@ include("precompile.jl")

# remove after julia LTS supports pkg extensions:
if !isdefined(Base, :get_extension)
include("../ext/LossFunctionsExt.jl")
include("../ext/LossFunctionsExt.jl")
include("../ext/ScientificTypesExt.jl")
end

const MEASURES_FOR_EXPORT = let measures = measures()
ret = Symbol[]
for C in keys(measures)
push!(ret, Symbol(C))
for alias in measures[C].aliases
alias == "precision" && continue
push!(ret, Symbol(alias))
end
end
Expand Down
4 changes: 4 additions & 0 deletions src/confusion_matrices.jl
Original file line number Diff line number Diff line change
Expand Up @@ -607,6 +607,10 @@ function Base.show(stream::IO, m::MIME"text/plain", cm::ConfusionMatrix{N}
write(stream, take!(iob))
end

# Compact, single-line representation of a confusion matrix: the type name, with its
# parameter `N`, wrapping the `repr` of the underlying matrix. The matrix is extracted
# with `warn=false`.
function Base.show(stream::IO, cm::ConfusionMatrix{N}) where N
    print(stream, "ConfusionMatrix{", N, "}(", repr(matrix(cm, warn=false)), ")")
end

# ## STATISTICAL FUNCTIONS ON CONFUSION MATRICES

Expand Down
2 changes: 1 addition & 1 deletion src/docstrings.jl
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,7 @@ function docstring(signature; body="", footer="", scitype="")
API.can_report_unaggregated(m) &&
(ret *= "\n\nMeasurements are aggregated. "*
"To obtain a separate measurement for each observation, "*
"use the syntax `measurements(m, ŷ, y)`. ")
"use the syntax `measurements($m_str, ŷ, y)`. ")
ret *= "Generally, an observation `obs` in `MLUtils.eachobs(y)` is expected to satisfy "*
"`ScientificTypes.scitype(obs)<:`$scitype. "
# if kind_of_proxy == LearnAPI.LiteralTarget()
Expand Down
2 changes: 1 addition & 1 deletion src/functions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -163,7 +163,7 @@ end

# # FUNCTIONS ON MATRICES INTERPRETED AS CONFUSION MATRICES

clean(s) = join(split(s, "_"), " ")
# Turn a possibly module-qualified, underscore-separated name into words: keep only the
# part after the final ".", then replace each underscore with a space.
function clean(s)
    unqualified = last(split(s, "."))
    return replace(unqualified, "_" => " ")
end
function docstring(measure; sig="(m)", name=clean(measure), binary=false, the=false)
footer = binary ?
"The first index corresponds to the \"negative\" class, "*
Expand Down
3 changes: 2 additions & 1 deletion src/probabilistic.jl
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,8 @@ Core implementation: [`Functions.auc`](@ref).
$INVARIANT_LABEL
""",
scitype = ""
scitype = "",
footer="See also [`roc_curve`](@ref). ",
)

"$AreaUnderCurveDoc"
Expand Down
87 changes: 83 additions & 4 deletions src/registry.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,17 +4,96 @@ const ERR_BAD_CONSTRUCTOR = ArgumentError(
"Constructor must have a zero argument method. "
)

# Build the error reported when `trait` is not a recognized measure trait name. The
# message lists the admissible names, taken from `API.OVERLOADABLE_TRAITS_LIST`.
function ERR_BAD_KWARG(trait)
    admissible = API.OVERLOADABLE_TRAITS_LIST
    return ArgumentError(
        "`$trait` is not a valid measure trait name. You must choose one of these: "*
        "$(admissible). "
    )
end

"""
StatisticalMeasusures.measures()
measures(; trait_options...)
*Experimental* and subject to breaking behavior between patch releases.
Return a dictionary, `dict`, keyed on measure constructors provided by
StatisticalMeasures.jl. The value of `dict[constructor]` provides information about
traits shared by all measures constructed using the syntax `constructor(args...)`.
StatisticalMeasures.jl. The value of `dict[constructor]` provides information about traits
(measure "metadata") shared by all measures constructed using the syntax
`constructor(args...)`.
# Trait options
One can filter on the basis of measure trait values, as shown in this example:
```
using StatisticalMeasures
import ScientificTypesBase.Multiclass
julia> measures(
observation_scitype = Union{Missing,Multiclass},
supports_class_weights = true,
)
```
---
measures(y; trait_filters...)
measures(yhat, y; trait_filters...)
*Experimental* and subject to breaking behavior between patch releases.
Assuming ScientificTypes.jl has been imported, find measures that can be applied to data
with the specified data arguments `(y,)` or `(yhat, y)`. It is assumed that the arguments
contain multiple observations (have types implementing `MLUtils.getobs`).
Returns a dictionary keyed on the constructors of such measures. Additional
`trait_filters` are the same as for the zero argument `measures` method.
```julia
using StatisticalMeasures
using ScientificTypes
julia> measures(rand(3), rand(3), supports_weights=false)
LittleDict{Any, Any, Vector{Any}, Vector{Any}} with 1 entry:
RSquared => (aliases = ("rsq", "rsquared"), consumes_multiple_observations = true, can_re…
```
*Warning.* Matching is based only on the *first* observation of the arguments provided,
and must be interpreted carefully if, for example, `y` or `yhat` are vectors with `Union`
or other abstract element types.
"""
function measures(; trait_options...)
    # Validate the requested trait names once, up front, rather than once for every
    # registry entry. This also guarantees that a bad name is reported even when there
    # are no registry entries to filter.
    for trait in keys(trait_options)
        trait in API.OVERLOADABLE_TRAITS || throw(ERR_BAD_KWARG(trait))
    end

    # Keep the constructors whose metadata agrees with *every* requested trait value;
    # with no options specified, nothing is filtered out.
    return filter(TRAITS_GIVEN_CONSTRUCTOR) do (_, metadata)
        all(trait_options) do (trait, value)
            getproperty(metadata, trait) == value
        end
    end
end

"""
measures() = TRAITS_GIVEN_CONSTRUCTOR
measures(needle::Union{AbstractString,Regex}; trait_options...)
*Experimental* and subject to breaking behavior between patch releases.
Find measures that contain `needle` in their document string. Returns a dictionary keyed
on the constructors of such measures.
```
julia> measures("Matthew")
LittleDict{Any, Any, Vector{Any}, Vector{Any}} with 1 entry:
MatthewsCorrelation => (aliases = ("matthews_correlation", "mcc"), consumes_multiple_obse…
```
"""
function measures(needle::Union{AbstractString,Regex}; trait_options...)
    # first narrow the candidates using any trait options provided:
    candidates = measures(; trait_options...)

    # then keep those whose document string, rendered as text, contains `needle`:
    return filter(candidates) do (constructor, _)
        docstring_text = string(Base.Docs.doc(constructor))
        occursin(needle, docstring_text)
    end
end

"""
StatisticalMeasures.register(constructor, aliases=String[])
Expand Down
1 change: 1 addition & 0 deletions src/tools.jl
Original file line number Diff line number Diff line change
Expand Up @@ -54,3 +54,4 @@ function API.check_pools(
end
return nothing
end

29 changes: 29 additions & 0 deletions test/ScientificTypesExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
using ScientificTypes
using ScientificTypes.Tables

n = 10
p = 3
y = rand(p, n)
yhat = rand(p, n)

# tabular versions of the data above, with observations as rows:
t = y' |> Tables.table |> Tables.columntable
that = yhat' |> Tables.table |> Tables.columntable

# search based on the ground truth only:
ms = measures(t)
@test all(ms) do (_, metadata)
    AbstractVector{Continuous} <: metadata.observation_scitype
end

# search based on predictions and ground truth:
ms = measures(that, t)
@test all(ms) do (_, metadata)
    AbstractVector{Continuous} <: metadata.observation_scitype
end

# probabilistic predictions of a binary target:
y = categorical(rand("ab", n))
yhat = UnivariateFinite(levels(y), rand(n), augment=true, pool=y)

# `yhat` and `y` must be passed as two separate arguments. Passing the single tuple
# `(yhat, y)` dispatches to the one-argument method instead, which appears to match no
# measure, so the tests below would pass vacuously.
ms2 = measures(yhat, y)
@test !isempty(ms2)
@test all(ms2) do (_, metadata)
    Multiclass{2} <: metadata.observation_scitype &&
        metadata.kind_of_proxy == LearnAPI.Distribution()
end

# the multi-target regression and probabilistic classification searches are disjoint:
@test isempty(intersect(keys(ms), keys(ms2)))
Loading

0 comments on commit be112a3

Please sign in to comment.