diff --git a/Project.toml b/Project.toml
index e9475460..66747872 100644
--- a/Project.toml
+++ b/Project.toml
@@ -11,6 +11,7 @@ Dates = "ade2ca70-3891-5945-98fb-dc099432e06a"
 DelimitedFiles = "8bb1440f-4735-579b-a4ab-409b98df4dab"
 Distributed = "8ba89e20-285c-5b6f-9357-94700520ee1b"
 Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
+FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"
 InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240"
 InvertedIndices = "41ab1584-1d38-5bbf-9106-f11c6c58b48f"
 LearnAPI = "92ad9a40-7767-427a-9ee6-6e577f1266cb"
@@ -44,6 +45,7 @@ CategoricalDistributions = "0.2"
 ComputationalResources = "0.3"
 DelimitedFiles = "1"
 Distributions = "0.25.3"
+FillArrays = "1.14.0"
 InvertedIndices = "1"
 LearnAPI = "2"
 MLJModelInterface = "1.11"
diff --git a/src/MLJBase.jl b/src/MLJBase.jl
index 03dff2a1..eba90059 100644
--- a/src/MLJBase.jl
+++ b/src/MLJBase.jl
@@ -91,6 +91,8 @@ const Dist = Distributions
 
 # Measures
 import StatisticalMeasuresBase
 
+import FillArrays
+
 # Plots
 using RecipesBase: RecipesBase, @recipe
diff --git a/src/resampling.jl b/src/resampling.jl
index b6266b20..604200e8 100644
--- a/src/resampling.jl
+++ b/src/resampling.jl
@@ -849,6 +849,10 @@ end
 # ---------------------------------------------------------------
 # Helpers
 
+# to fill out predictions in the case of density estimation ("cone" construction):
+fill_if_needed(yhat, X, n) = yhat
+fill_if_needed(yhat, X::Nothing, n) = FillArrays.Fill(yhat, n)
+
 function actual_rows(rows, N, verbosity)
     unspecified_rows = (rows === nothing)
     _rows = unspecified_rows ? (1:N) : rows
@@ -1470,10 +1474,15 @@ function evaluate!(
     function fit_and_extract_on_fold(mach, k)
         train, test = resampling[k]
         fit!(mach; rows=train, verbosity=verbosity - 1, force=force)
+        ntest = MLJBase.nrows(test)
         # build a dictionary of predictions keyed on the operations
         # that appear (`predict`, `predict_mode`, etc):
         yhat_given_operation =
-            Dict(op=>op(mach, rows=test) for op in unique(operations))
+            Dict(op=>
+                fill_if_needed(op(mach, rows=test), X, ntest)
+                for op in unique(operations))
+        # Note: `fill_if_needed(yhat, X, n) = yhat` in the typical case that
+        # `X` is different from `nothing`.
         ytest = selectrows(y, test)
 
         if per_observation_flag
diff --git a/test/interface/model_api.jl b/test/interface/model_api.jl
index 8966f70f..ed338a68 100644
--- a/test/interface/model_api.jl
+++ b/test/interface/model_api.jl
@@ -2,10 +2,7 @@ module TestModelAPI
 
 using Test
 using MLJBase
-using StatisticalMeasures
-import MLJModelInterface
 using ..Models
-using Distributions
 using StableRNGs
 
 rng = StableRNG(661)
@@ -30,57 +27,5 @@
     @test predict_mode(rgs, fitresult, X)[1] == 3
 end
 
-mutable struct UnivariateFiniteFitter <: MLJModelInterface.Probabilistic
-    alpha::Float64
-end
-UnivariateFiniteFitter(;alpha=1.0) = UnivariateFiniteFitter(alpha)
-
-@testset "models that fit a distribution" begin
-    function MLJModelInterface.fit(model::UnivariateFiniteFitter,
-                                   verbosity, X, y)
-
-        α = model.alpha
-        N = length(y)
-        _classes = classes(y)
-        d = length(_classes)
-
-        frequency_given_class = Distributions.countmap(y)
-        prob_given_class =
-            Dict(c => (frequency_given_class[c] + α)/(N + α*d) for c in _classes)
-
-        fitresult = MLJBase.UnivariateFinite(prob_given_class)
-
-        report = (params=Distributions.params(fitresult),)
-        cache = nothing
-
-        verbosity > 0 && @info "Fitted a $fitresult"
-
-        return fitresult, cache, report
-    end
-
-    MLJModelInterface.predict(model::UnivariateFiniteFitter,
-                              fitresult,
-                              X) = fitresult
-
-
-    MLJModelInterface.input_scitype(::Type{<:UnivariateFiniteFitter}) =
-        Nothing
-    MLJModelInterface.target_scitype(::Type{<:UnivariateFiniteFitter}) =
-        AbstractVector{<:Finite}
-
-    y = coerce(collect("aabbccaa"), Multiclass)
-    X = nothing
-    model = UnivariateFiniteFitter(alpha=0)
-    mach = machine(model, X, y)
-    fit!(mach, verbosity=0)
-
-    ytest = y[1:3]
-    yhat = predict(mach, nothing) # single UnivariateFinite distribution
-
-    @test cross_entropy(fill(yhat, 3), ytest) ≈
-        mean([-log(1/2), -log(1/2), -log(1/4)])
-
-end
-
 end
 true
diff --git a/test/resampling.jl b/test/resampling.jl
index 688147e1..04ebe3ea 100644
--- a/test/resampling.jl
+++ b/test/resampling.jl
@@ -6,6 +6,8 @@ import Tables
 @everywhere import StatisticalMeasures.StatisticalMeasuresBase as API
 using StatisticalMeasures
 import LearnAPI
+import CategoricalDistributions
+import MLJModelInterface
 
 @everywhere begin
     using .Models
@@ -1001,4 +1003,77 @@ MLJBase.save(logger::DummyLogger, mach::Machine) = mach.model
     @test MLJBase.save(mach) == model
 end
 
+
+# # RESAMPLING FOR DENSITY ESTIMATORS
+
+# We define a density estimator that fits a `UnivariateFinite` distribution to some
+# categorical data, with a Laplace smoothing option, α.
+
+mutable struct UnivariateFiniteFitter <: MLJModelInterface.Probabilistic
+    alpha::Float64
+end
+UnivariateFiniteFitter(;alpha=1.0) = UnivariateFiniteFitter(alpha)
+
+function MLJModelInterface.fit(model::UnivariateFiniteFitter,
+                               verbosity, X, y)
+
+    α = model.alpha
+    N = length(y)
+    _classes = classes(y)
+    d = length(_classes)
+
+    frequency_given_class = Distributions.countmap(y)
+    prob_given_class =
+        Dict(c => (get(frequency_given_class, c, 0) + α)/(N + α*d) for c in _classes)
+
+    fitresult = CategoricalDistributions.UnivariateFinite(prob_given_class)
+
+    report = (params=Distributions.params(fitresult),)
+    cache = nothing
+
+    verbosity > 0 && @info "Fitted a $fitresult"
+
+    return fitresult, cache, report
+end
+
+MLJModelInterface.predict(model::UnivariateFiniteFitter,
+                          fitresult,
+                          X) = fitresult
+
+
+MLJModelInterface.input_scitype(::Type{<:UnivariateFiniteFitter}) =
+    Nothing
+MLJModelInterface.target_scitype(::Type{<:UnivariateFiniteFitter}) =
+    AbstractVector{<:Finite}
+
+@testset "resampling for density estimators" begin
+    y = coerce(rand(StableRNG(123), "abc", 20), Multiclass)
+    X = nothing
+
+    train, test = partition(eachindex(y), 0.8)
+
+    model = UnivariateFiniteFitter(alpha=0)
+
+    mach = machine(model, X, y)
+    fit!(mach, rows=train, verbosity=0)
+
+    ytest = y[test]
+    yhat = predict(mach, nothing) # single UnivariateFinite distribution
+
+    # Estimate the out-of-sample loss. Notice we have to make duplicate copies of
+    # `yhat`, to match the number of ground-truth observations with which we are
+    # pairing it (the "cone" construction):
+    by_hand = log_loss(fill(yhat, length(ytest)), ytest)
+
+    # test some behavior on which the implementation of `evaluate` for density
+    # estimators is predicated:
+    @test isnothing(selectrows(X, 1:3))
+    @test predict(mach, rows=1:3) ≈ yhat
+
+    # `evaluate` has an internal "cone" construction when `X === nothing`, so this
+    # should just work:
+    e = evaluate(model, X, y, resampling=[(train, test)], measure=log_loss)
+    @test e.measurement[1] ≈ by_hand
+end
+
 true
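
Below, a minimal sketch of the new `fill_if_needed` dispatch in isolation (the sample values are illustrative, not from the diff; only FillArrays is assumed):

import FillArrays

# reproduced from the src/resampling.jl hunk above:
fill_if_needed(yhat, X, n) = yhat
fill_if_needed(yhat, X::Nothing, n) = FillArrays.Fill(yhat, n)

# typical case: `X` is a table, so predictions pass through unchanged:
fill_if_needed([0.1, 0.9], (x1=[1, 2],), 2)  # == [0.1, 0.9]

# density estimation: `X === nothing`, so the single prediction is lazily
# replicated once per test row, using O(1) memory:
fill_if_needed(:distribution, nothing, 3)    # == FillArrays.Fill(:distribution, 3)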
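
And a self-contained sketch of the "cone" scoring that the new testset checks by hand, assuming the test dependencies CategoricalArrays, CategoricalDistributions and StatisticalMeasures (the data and probabilities are illustrative):

using CategoricalArrays
import CategoricalDistributions
using StatisticalMeasures

y = categorical(collect("aabbc"))

# a single distribution, as a density estimator's `predict` returns:
d = CategoricalDistributions.UnivariateFinite(levels(y), [0.4, 0.4, 0.2], pool=y)

# pair the one distribution with every ground-truth observation...
yhat = fill(d, length(y))

# ...so that a per-observation measure applies:
log_loss(yhat, y)  # mean of -log(pdf(d, yi)) over the observations yi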