From f4024daaffdc4f174512405416ccb1004a4728e2 Mon Sep 17 00:00:00 2001 From: "Anthony D. Blaom" Date: Tue, 21 Oct 2025 12:38:25 +1300 Subject: [PATCH 1/4] add evaluate support for density estimators --- Project.toml | 2 + src/MLJBase.jl | 2 + src/resampling.jl | 11 +++++- test/interface/model_api.jl | 55 --------------------------- test/resampling.jl | 76 +++++++++++++++++++++++++++++++++++++ 5 files changed, 90 insertions(+), 56 deletions(-) diff --git a/Project.toml b/Project.toml index 23912096..537e439c 100644 --- a/Project.toml +++ b/Project.toml @@ -11,6 +11,7 @@ Dates = "ade2ca70-3891-5945-98fb-dc099432e06a" DelimitedFiles = "8bb1440f-4735-579b-a4ab-409b98df4dab" Distributed = "8ba89e20-285c-5b6f-9357-94700520ee1b" Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" +FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b" InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240" InvertedIndices = "41ab1584-1d38-5bbf-9106-f11c6c58b48f" LearnAPI = "92ad9a40-7767-427a-9ee6-6e577f1266cb" @@ -44,6 +45,7 @@ CategoricalDistributions = "0.1" ComputationalResources = "0.3" DelimitedFiles = "1" Distributions = "0.25.3" +FillArrays = "1.14.0" InvertedIndices = "1" LearnAPI = "1" MLJModelInterface = "1.11" diff --git a/src/MLJBase.jl b/src/MLJBase.jl index 03dff2a1..eba90059 100644 --- a/src/MLJBase.jl +++ b/src/MLJBase.jl @@ -91,6 +91,8 @@ const Dist = Distributions # Measures import StatisticalMeasuresBase +import FillArrays + # Plots using RecipesBase: RecipesBase, @recipe diff --git a/src/resampling.jl b/src/resampling.jl index b6266b20..604200e8 100644 --- a/src/resampling.jl +++ b/src/resampling.jl @@ -849,6 +849,10 @@ end # --------------------------------------------------------------- # Helpers +# to fill out predictions in the case of density estimation ("cone" construction): +fill_if_needed(yhat, X, n) = yhat +fill_if_needed(yhat, X::Nothing, n) = FillArrays.Fill(yhat, n) + function actual_rows(rows, N, verbosity) unspecified_rows = (rows === nothing) _rows = 
unspecified_rows ? (1:N) : rows @@ -1470,10 +1474,15 @@ function evaluate!( function fit_and_extract_on_fold(mach, k) train, test = resampling[k] fit!(mach; rows=train, verbosity=verbosity - 1, force=force) + ntest = MLJBase.nrows(test) # build a dictionary of predictions keyed on the operations # that appear (`predict`, `predict_mode`, etc): yhat_given_operation = - Dict(op=>op(mach, rows=test) for op in unique(operations)) + Dict(op=> + fill_if_needed(op(mach, rows=test), X, ntest) + for op in unique(operations)) + # Note: `fill_if_need(yhat, X, n) = yhat` in typical case that `X` is different + # from `nothing`. ytest = selectrows(y, test) if per_observation_flag diff --git a/test/interface/model_api.jl b/test/interface/model_api.jl index 8966f70f..ed338a68 100644 --- a/test/interface/model_api.jl +++ b/test/interface/model_api.jl @@ -2,10 +2,7 @@ module TestModelAPI using Test using MLJBase -using StatisticalMeasures -import MLJModelInterface using ..Models -using Distributions using StableRNGs rng = StableRNG(661) @@ -30,57 +27,5 @@ rng = StableRNG(661) @test predict_mode(rgs, fitresult, X)[1] == 3 end -mutable struct UnivariateFiniteFitter <: MLJModelInterface.Probabilistic - alpha::Float64 -end -UnivariateFiniteFitter(;alpha=1.0) = UnivariateFiniteFitter(alpha) - -@testset "models that fit a distribution" begin - function MLJModelInterface.fit(model::UnivariateFiniteFitter, - verbosity, X, y) - - α = model.alpha - N = length(y) - _classes = classes(y) - d = length(_classes) - - frequency_given_class = Distributions.countmap(y) - prob_given_class = - Dict(c => (frequency_given_class[c] + α)/(N + α*d) for c in _classes) - - fitresult = MLJBase.UnivariateFinite(prob_given_class) - - report = (params=Distributions.params(fitresult),) - cache = nothing - - verbosity > 0 && @info "Fitted a $fitresult" - - return fitresult, cache, report - end - - MLJModelInterface.predict(model::UnivariateFiniteFitter, - fitresult, - X) = fitresult - - - 
MLJModelInterface.input_scitype(::Type{<:UnivariateFiniteFitter}) = - Nothing - MLJModelInterface.target_scitype(::Type{<:UnivariateFiniteFitter}) = - AbstractVector{<:Finite} - - y = coerce(collect("aabbccaa"), Multiclass) - X = nothing - model = UnivariateFiniteFitter(alpha=0) - mach = machine(model, X, y) - fit!(mach, verbosity=0) - - ytest = y[1:3] - yhat = predict(mach, nothing) # single UnivariateFinite distribution - - @test cross_entropy(fill(yhat, 3), ytest) ≈ - mean([-log(1/2), -log(1/2), -log(1/4)]) - -end - end true diff --git a/test/resampling.jl b/test/resampling.jl index 688147e1..690ae8a9 100644 --- a/test/resampling.jl +++ b/test/resampling.jl @@ -6,6 +6,8 @@ import Tables @everywhere import StatisticalMeasures.StatisticalMeasuresBase as API using StatisticalMeasures import LearnAPI +import CategoricalDistributions +import MLJModelInterface @everywhere begin using .Models @@ -1001,4 +1003,78 @@ MLJBase.save(logger::DummyLogger, mach::Machine) = mach.model @test MLJBase.save(mach) == model end + +# # RESAMPLING FOR DENSITY ESTIMATORS + +# we define a density estimator to fit a `UnivariateFinite` distribution to some +# Categorical data, with a Laplace smoothing option, α. 
+ +mutable struct UnivariateFiniteFitter <: MLJModelInterface.Probabilistic + alpha::Float64 +end +UnivariateFiniteFitter(;alpha=1.0) = UnivariateFiniteFitter(alpha) + +function MLJModelInterface.fit(model::UnivariateFiniteFitter, + verbosity, X, y) + + α = model.alpha + N = length(y) + _classes = classes(y) + d = length(_classes) + + frequency_given_class = Distributions.countmap(y) + prob_given_class = + Dict(c => (get(frequency_given_class, c, 0) + α)/(N + α*d) for c in _classes) + + fitresult = CategoricalDistributions.UnivariateFinite(prob_given_class) + + report = (params=Distributions.params(fitresult),) + cache = nothing + + verbosity > 0 && @info "Fitted a $fitresult" + + return fitresult, cache, report +end + +MLJModelInterface.predict(model::UnivariateFiniteFitter, + fitresult, + X) = fitresult + + +MLJModelInterface.input_scitype(::Type{<:UnivariateFiniteFitter}) = + Nothing +MLJModelInterface.target_scitype(::Type{<:UnivariateFiniteFitter}) = + AbstractVector{<:Finite} + +@testset "resampling for density estimators" begin + y = coerce(collect("aabbccaa"), Multiclass) + X = nothing + + train, test = partition(eachindex(y), 0.8) + + # this model type is defined in /src/interface/model_api + model = UnivariateFiniteFitter(alpha=0) + + mach = machine(model, X, y) + fit!(mach, rows=train, verbosity=0) + + ytest = y[test] + yhat = predict(mach, nothing) # single UnivariateFinite distribution + + # Estmiate out-of-sample loss. 
Notice we have to make duplicate versions `yhat`, to + # match the number ground truth observations with which we are pairing it ("cone" + # construction): + by_hand = log_loss(fill(yhat, length(ytest)), ytest) + + # test some behaviour on which the implementation of `evaluate` for density estimators + # is predicated: + @test isnothing(selectrows(X, 1:3)) + @test predict(mach, rows=1:3) ≈ yhat + + # evaluate has an internal "cone" construction when `X = nothing`, so this should just + # work: + e = evaluate(model, X, y, resampling=[(train, test)], measure=log_loss) + @test e.measurement[1] ≈ by_hand +end + true From 398c07ce883684300bacf540e6f44602a4989a8d Mon Sep 17 00:00:00 2001 From: "Anthony D. Blaom" Date: Tue, 21 Oct 2025 13:02:33 +1300 Subject: [PATCH 2/4] improve test --- test/resampling.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/resampling.jl b/test/resampling.jl index 690ae8a9..7a6364f4 100644 --- a/test/resampling.jl +++ b/test/resampling.jl @@ -1047,7 +1047,7 @@ MLJModelInterface.target_scitype(::Type{<:UnivariateFiniteFitter}) = AbstractVector{<:Finite} @testset "resampling for density estimators" begin - y = coerce(collect("aabbccaa"), Multiclass) + y = coerce(rand(StableRNG(123), "abc", 20), Multiclass) X = nothing train, test = partition(eachindex(y), 0.8) From f18269dbabb0c0238496f196327a0c302267beb9 Mon Sep 17 00:00:00 2001 From: "Anthony D. Blaom" Date: Mon, 3 Nov 2025 10:42:04 +1300 Subject: [PATCH 3/4] typos --- test/resampling.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/test/resampling.jl b/test/resampling.jl index 7a6364f4..fd8f8a3f 100644 --- a/test/resampling.jl +++ b/test/resampling.jl @@ -1061,12 +1061,12 @@ MLJModelInterface.target_scitype(::Type{<:UnivariateFiniteFitter}) = ytest = y[test] yhat = predict(mach, nothing) # single UnivariateFinite distribution - # Estmiate out-of-sample loss. Notice we have to make duplicate versions `yhat`, to + # Estimate out-of-sample loss. 
Notice we have to make duplicate versions of `yhat`, to
     # match the number ground truth observations with which we are pairing it ("cone"
     # construction):
     by_hand = log_loss(fill(yhat, length(ytest)), ytest)
 
-    # test some behaviour on which the implementation of `evaluate` for density estimators
+    # test some behavior on which the implementation of `evaluate` for density estimators
     # is predicated:
     @test isnothing(selectrows(X, 1:3))
     @test predict(mach, rows=1:3) ≈ yhat

From 9d6a86f27e24c96a380b4d198bd42040af8144c3 Mon Sep 17 00:00:00 2001
From: "Anthony D. Blaom"
Date: Fri, 5 Dec 2025 06:53:23 +1300
Subject: [PATCH 4/4] remove out-dated code comment

---
 test/resampling.jl | 1 -
 1 file changed, 1 deletion(-)

diff --git a/test/resampling.jl b/test/resampling.jl
index fd8f8a3f..04ebe3ea 100644
--- a/test/resampling.jl
+++ b/test/resampling.jl
@@ -1052,7 +1052,6 @@ MLJModelInterface.target_scitype(::Type{<:UnivariateFiniteFitter}) =
 
     train, test = partition(eachindex(y), 0.8)
 
-    # this model type is defined in /src/interface/model_api
     model = UnivariateFiniteFitter(alpha=0)
 
     mach = machine(model, X, y)