Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions .stats.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
{
"lean_modules": 61,
"lean_lines": 32558,
"_note": "Generated by scripts/repo-stats.sh. Do not edit by hand."
}
1 change: 1 addition & 0 deletions FormalSLT.lean
Original file line number Diff line number Diff line change
Expand Up @@ -67,3 +67,4 @@ import FormalSLT.PACBayesFiniteProductMGF
import FormalSLT.PACBayesBoundedLoss
import FormalSLT.PACBayesMcAllester
import FormalSLT.PACBayesBernstein
import FormalSLT.PACBayesMargin
343 changes: 343 additions & 0 deletions FormalSLT/PACBayesBernstein.lean
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ Released under MIT license as described in the file LICENSE.
Authors: Robby Sneiderman
-/
import FormalSLT.PACBayesKL
import FormalSLT.Probability.FiniteUnionBound
import Mathlib.Algebra.BigOperators.Field
import Mathlib.Algebra.Order.BigOperators.Group.Finset
import Mathlib.Analysis.SpecialFunctions.Sqrt
Expand Down Expand Up @@ -37,6 +38,8 @@ classifier margin loss.
Markov/confidence layer for a fixed Bernstein parameter.
* `finitePACBayesBernsteinMargin_badEventMass_le_delta` — posterior-dependent
margin-style wrapper under explicit complexity and penalty certificates.
* `finitePACBayesBernsteinGridOptimized_badEventMass_le_delta` — finite-grid
optimization wrapper over supplied Bernstein parameter buckets.

## Current boundaries

Expand Down Expand Up @@ -94,6 +97,114 @@ def expectedPriorBernsteinExpMoment [Fintype Ω] [Fintype ι]
∑ ω, ν ω *
priorBernsteinExpMoment π lambda scale riskFn empiricalRiskFn varianceProxy ω

/--
Expected normalized PAC-Bayes Bernstein prior moment from per-hypothesis MGF
budgets.

This is the algebraic bridge used before plugging in a concrete iid
concentration lemma: each hypothesis supplies the Bernstein exponential budget,
then the prior-weighted normalized moment has expectation at most one.
-/
theorem expectedPriorBernsteinExpMoment_le_one_of_mgf_bound
[Fintype Ω] [Fintype ι]
{π : ι → ℝ} (hπ : IsPMF π)
(ν : Ω → ℝ) (lambda scale : ℝ)
(riskFn : ι → ℝ) (empiricalRiskFn : Ω → ι → ℝ)
(varianceProxy : ι → ℝ)
(hmgf :
∀ i : ι,
(∑ ω : Ω, ν ω *
Real.exp (lambda * (riskFn i - empiricalRiskFn ω i))) ≤
Real.exp
(lambda ^ 2 * varianceProxy i /
(2 * (1 - scale * lambda)))) :
expectedPriorBernsteinExpMoment ν π lambda scale riskFn empiricalRiskFn
varianceProxy ≤ 1 := by
classical
let budget : ι → ℝ :=
fun i =>
lambda ^ 2 * varianceProxy i /
(2 * (1 - scale * lambda))
have hexp_split : ∀ a b : ℝ, Real.exp (a - b) = Real.exp a * Real.exp (-b) := by
intro a b
rw [← Real.exp_add]
ring_nf
have hswap :
expectedPriorBernsteinExpMoment ν π lambda scale riskFn empiricalRiskFn
varianceProxy =
∑ i : ι,
π i * Real.exp (-(budget i)) *
(∑ ω : Ω, ν ω *
Real.exp (lambda * (riskFn i - empiricalRiskFn ω i))) := by
unfold expectedPriorBernsteinExpMoment priorBernsteinExpMoment
calc
(∑ ω : Ω,
ν ω *
∑ i : ι,
π i *
Real.exp
(lambda * (riskFn i - empiricalRiskFn ω i) -
lambda ^ 2 * varianceProxy i /
(2 * (1 - scale * lambda))))
=
∑ ω : Ω, ∑ i : ι,
ν ω *
(π i *
Real.exp
(lambda * (riskFn i - empiricalRiskFn ω i) -
budget i)) := by
apply Finset.sum_congr rfl
intro ω _hω
rw [Finset.mul_sum]
_ =
∑ i : ι, ∑ ω : Ω,
ν ω *
(π i *
Real.exp
(lambda * (riskFn i - empiricalRiskFn ω i) -
budget i)) := by
rw [Finset.sum_comm]
_ =
∑ i : ι,
π i * Real.exp (-(budget i)) *
(∑ ω : Ω, ν ω *
Real.exp (lambda * (riskFn i - empiricalRiskFn ω i))) := by
apply Finset.sum_congr rfl
intro i _hi
rw [Finset.mul_sum]
apply Finset.sum_congr rfl
intro ω _hω
rw [hexp_split]
ring
rw [hswap]
calc
(∑ i : ι,
π i * Real.exp (-(budget i)) *
(∑ ω : Ω, ν ω *
Real.exp (lambda * (riskFn i - empiricalRiskFn ω i))))
∑ i : ι,
π i * Real.exp (-(budget i)) *
Real.exp (budget i) := by
apply Finset.sum_le_sum
intro i _hi
exact mul_le_mul_of_nonneg_left
(hmgf i)
(mul_nonneg (hπ.nonneg i) (le_of_lt (Real.exp_pos _)))
_ = ∑ i : ι, π i := by
apply Finset.sum_congr rfl
intro i _hi
calc
π i * Real.exp (-(budget i)) * Real.exp (budget i)
= π i * (Real.exp (-(budget i)) * Real.exp (budget i)) := by ring
_ = π i * 1 := by
congr 1
rw [← Real.exp_add]
ring_nf
simp
_ = π i := by ring
_ = 1 := hπ.sum_one

/-- Finite sample mass of outcomes whose Bernstein prior moment exceeds a threshold. -/
def priorBernsteinExpMomentTailMass [Fintype Ω] [Fintype ι]
(ν : Ω → ℝ) (π : ι → ℝ)
Expand Down Expand Up @@ -556,6 +667,238 @@ theorem finitePACBayesBernsteinMargin_badEventMass_le_delta
scale * complexityOf ρ)
hcomplexity hpenalty hExpected

/-! ### Finite-grid peeling over Bernstein parameters -/

/--
Samples where some posterior violates at least one finite Bernstein parameter
bucket.

Each grid index supplies a parameter `lambdaOf g`, a scale `scaleOf g`, and a
confidence allocation `confidenceOf g`. This is a finite-grid statement, not an
optimization over all real parameters.
-/
def finitePACBayesBernsteinGridBadSamples
[Fintype Ω] [Fintype ι] [Fintype γ]
(π : ι → ℝ) (lambdaOf scaleOf confidenceOf : γ → ℝ)
(riskFn : ι → ℝ) (empiricalRiskFn : Ω → ι → ℝ)
(varianceProxy : ι → ℝ) : Finset Ω :=
Finset.univ.filter fun ω =>
∃ g : γ, ∃ ρ : ι → ℝ,
IsPMF ρ ∧
posteriorGeneralizationGap ρ riskFn (empiricalRiskFn ω) >
(klDiv ρ π + Real.log (1 / confidenceOf g)) / lambdaOf g +
lambdaOf g * posteriorMarginVarianceProxy ρ varianceProxy /
(2 * (1 - scaleOf g * lambdaOf g))

/--
Finite-grid Bernstein bad-event bound.

For each finite grid bucket `g`, the fixed-`lambda` Bernstein theorem is applied
with its own confidence allocation. A finite weighted union bound controls the
mass of samples on which any bucket fails.
-/
theorem finitePACBayesBernsteinGrid_badEventMass_le_sum_delta
[Fintype Ω] [DecidableEq Ω] [Fintype ι] [Nonempty ι] [Fintype γ]
{ν : Ω → ℝ} (hν : IsPMF ν)
{π : ι → ℝ} (hπ : IsFullSupportPMF π)
(lambdaOf scaleOf confidenceOf : γ → ℝ)
(hlambda : ∀ g : γ, 0 < lambdaOf g)
(hscale : ∀ g : γ, scaleOf g * lambdaOf g < 1)
(hconfidenceOf : ∀ g : γ, 0 < confidenceOf g)
(riskFn : ι → ℝ) (empiricalRiskFn : Ω → ι → ℝ)
(varianceProxy : ι → ℝ)
(hExpected :
∀ g : γ,
expectedPriorBernsteinExpMoment ν π (lambdaOf g) (scaleOf g)
riskFn empiricalRiskFn varianceProxy ≤ 1) :
(∑ ω ∈
finitePACBayesBernsteinGridBadSamples π lambdaOf scaleOf confidenceOf
riskFn empiricalRiskFn varianceProxy,
ν ω) ≤
∑ g : γ, confidenceOf g := by
classical
let bucketEvent : γ → Finset Ω := fun g =>
finitePACBayesBernsteinFixedLambdaBadSamples π (lambdaOf g) (scaleOf g)
(confidenceOf g) riskFn empiricalRiskFn varianceProxy
let gridUnionEvent : Finset Ω :=
Finset.univ.filter fun ω : Ω => ∃ g : γ, ω ∈ bucketEvent g
have hsubset :
finitePACBayesBernsteinGridBadSamples π lambdaOf scaleOf confidenceOf
riskFn empiricalRiskFn varianceProxy
⊆ gridUnionEvent := by
intro ω hω
rw [finitePACBayesBernsteinGridBadSamples, Finset.mem_filter] at hω
rcases hω.2 with ⟨g, ρ, hρ, hbad⟩
change ω ∈ (Finset.univ.filter fun ω : Ω => ∃ g : γ, ω ∈ bucketEvent g)
rw [Finset.mem_filter]
refine ⟨Finset.mem_univ ω, g, ?_⟩
change ω ∈ finitePACBayesBernsteinFixedLambdaBadSamples π (lambdaOf g) (scaleOf g)
(confidenceOf g) riskFn empiricalRiskFn varianceProxy
rw [finitePACBayesBernsteinFixedLambdaBadSamples, Finset.mem_filter]
exact ⟨Finset.mem_univ ω, ρ, hρ, hbad⟩
have hmass_subset :
(∑ ω ∈
finitePACBayesBernsteinGridBadSamples π lambdaOf scaleOf confidenceOf
riskFn empiricalRiskFn varianceProxy,
ν ω)
≤ ∑ ω ∈ gridUnionEvent, ν ω := by
exact Finset.sum_le_sum_of_subset_of_nonneg hsubset (by
intro ω _hω _hnot
exact hν.nonneg ω)
have hunion :=
FormalSLT.Probability.FiniteUnionBound.finiteProbabilityUnionBound_proof
(support := (Finset.univ : Finset Ω))
(w := ν)
(events := bucketEvent)
(s := (Finset.univ : Finset γ))
(fun ω : Ω => hν.nonneg ω)
have hunion_mass :
(∑ ω ∈ gridUnionEvent, ν ω) ≤
∑ g : γ, ∑ ω ∈ bucketEvent g, ν ω := by
change
(∑ ω ∈ (Finset.univ.filter fun ω : Ω => ∃ g : γ, ω ∈ bucketEvent g), ν ω) ≤
∑ g : γ, ∑ ω ∈ bucketEvent g, ν ω
unfold FormalSLT.Probability.FiniteUnionBound.finiteUnionEventMass at hunion
unfold FormalSLT.Probability.FiniteUnionBound.finiteEventMassSum at hunion
unfold FormalSLT.Probability.FiniteUnionBound.finiteEventMass at hunion
simp_rw [← Finset.sum_filter] at hunion
simpa using hunion
have hbucket :
∀ g : γ, (∑ ω ∈ bucketEvent g, ν ω) ≤ confidenceOf g := by
intro g
simpa [bucketEvent] using
finitePACBayesBernstein_fixedLambda_badEventMass_le_delta
(ν := ν) hν hπ (lambdaOf g) (scaleOf g) (confidenceOf g)
(hlambda g) (hscale g) (hconfidenceOf g) riskFn empiricalRiskFn
varianceProxy (hExpected g)
have hbucket_sum :
(∑ g : γ, ∑ ω ∈ bucketEvent g, ν ω) ≤ ∑ g : γ, confidenceOf g := by
exact Finset.sum_le_sum (fun g _hg => hbucket g)
exact hmass_subset.trans (hunion_mass.trans hbucket_sum)

/-- Finite-grid Bernstein bad-event bound with an explicit confidence budget. -/
theorem finitePACBayesBernsteinGrid_badEventMass_le_delta
[Fintype Ω] [DecidableEq Ω] [Fintype ι] [Nonempty ι] [Fintype γ]
{ν : Ω → ℝ} (hν : IsPMF ν)
{π : ι → ℝ} (hπ : IsFullSupportPMF π)
(lambdaOf scaleOf confidenceOf : γ → ℝ)
(hlambda : ∀ g : γ, 0 < lambdaOf g)
(hscale : ∀ g : γ, scaleOf g * lambdaOf g < 1)
(hconfidenceOf : ∀ g : γ, 0 < confidenceOf g)
(riskFn : ι → ℝ) (empiricalRiskFn : Ω → ι → ℝ)
(varianceProxy : ι → ℝ)
(hExpected :
∀ g : γ,
expectedPriorBernsteinExpMoment ν π (lambdaOf g) (scaleOf g)
riskFn empiricalRiskFn varianceProxy ≤ 1)
{delta : ℝ} (hgridConfidence : (∑ g : γ, confidenceOf g) ≤ delta) :
(∑ ω ∈
finitePACBayesBernsteinGridBadSamples π lambdaOf scaleOf confidenceOf
riskFn empiricalRiskFn varianceProxy,
ν ω) ≤ delta :=
(finitePACBayesBernsteinGrid_badEventMass_le_sum_delta
(Ω := Ω) (ι := ι) (γ := γ) hν hπ lambdaOf scaleOf confidenceOf
hlambda hscale hconfidenceOf riskFn empiricalRiskFn varianceProxy hExpected).trans
hgridConfidence

/--
Finite-grid optimized Bernstein bad-event bound.

The user-supplied `posteriorPenalty` may depend on the posterior. The theorem
requires a finite grid certificate: each posterior is assigned to a bucket whose
fixed-parameter Bernstein penalty is no larger than the supplied penalty.
-/
theorem finitePACBayesBernsteinGridOptimized_badEventMass_le_sum_delta
[Fintype Ω] [DecidableEq Ω] [Fintype ι] [Nonempty ι] [Fintype γ]
{ν : Ω → ℝ} (hν : IsPMF ν)
{π : ι → ℝ} (hπ : IsFullSupportPMF π)
(lambdaOf scaleOf confidenceOf : γ → ℝ)
(posteriorPenalty : (ι → ℝ) → ℝ)
(hlambda : ∀ g : γ, 0 < lambdaOf g)
(hscale : ∀ g : γ, scaleOf g * lambdaOf g < 1)
(hconfidenceOf : ∀ g : γ, 0 < confidenceOf g)
(riskFn : ι → ℝ) (empiricalRiskFn : Ω → ι → ℝ)
(varianceProxy : ι → ℝ)
(hgridCovers :
∀ ρ : ι → ℝ, IsPMF ρ →
∃ g : γ,
(klDiv ρ π + Real.log (1 / confidenceOf g)) / lambdaOf g +
lambdaOf g * posteriorMarginVarianceProxy ρ varianceProxy /
(2 * (1 - scaleOf g * lambdaOf g))
≤ posteriorPenalty ρ)
(hExpected :
∀ g : γ,
expectedPriorBernsteinExpMoment ν π (lambdaOf g) (scaleOf g)
riskFn empiricalRiskFn varianceProxy ≤ 1) :
(∑ ω ∈
finitePACBayesBernsteinPenaltyBadSamples riskFn empiricalRiskFn posteriorPenalty,
ν ω) ≤
∑ g : γ, confidenceOf g := by
classical
have hsubset :
finitePACBayesBernsteinPenaltyBadSamples riskFn empiricalRiskFn posteriorPenalty
finitePACBayesBernsteinGridBadSamples π lambdaOf scaleOf confidenceOf
riskFn empiricalRiskFn varianceProxy := by
intro ω hω
rw [finitePACBayesBernsteinPenaltyBadSamples, Finset.mem_filter] at hω
rcases hω.2 with ⟨ρ, hρ, hbad⟩
rcases hgridCovers ρ hρ with ⟨g, hpenalty⟩
rw [finitePACBayesBernsteinGridBadSamples, Finset.mem_filter]
refine ⟨Finset.mem_univ ω, g, ρ, hρ, ?_⟩
linarith
have hmass_subset :
(∑ ω ∈
finitePACBayesBernsteinPenaltyBadSamples riskFn empiricalRiskFn posteriorPenalty,
ν ω)
∑ ω ∈
finitePACBayesBernsteinGridBadSamples π lambdaOf scaleOf confidenceOf
riskFn empiricalRiskFn varianceProxy,
ν ω := by
exact Finset.sum_le_sum_of_subset_of_nonneg hsubset (by
intro ω _hω _hnot
exact hν.nonneg ω)
exact hmass_subset.trans
(finitePACBayesBernsteinGrid_badEventMass_le_sum_delta
(Ω := Ω) (ι := ι) (γ := γ) hν hπ lambdaOf scaleOf confidenceOf
hlambda hscale hconfidenceOf riskFn empiricalRiskFn varianceProxy hExpected)

/--
Finite-grid optimized Bernstein bad-event bound with an explicit total
confidence budget.
-/
theorem finitePACBayesBernsteinGridOptimized_badEventMass_le_delta
[Fintype Ω] [DecidableEq Ω] [Fintype ι] [Nonempty ι] [Fintype γ]
{ν : Ω → ℝ} (hν : IsPMF ν)
{π : ι → ℝ} (hπ : IsFullSupportPMF π)
(lambdaOf scaleOf confidenceOf : γ → ℝ)
(posteriorPenalty : (ι → ℝ) → ℝ)
(hlambda : ∀ g : γ, 0 < lambdaOf g)
(hscale : ∀ g : γ, scaleOf g * lambdaOf g < 1)
(hconfidenceOf : ∀ g : γ, 0 < confidenceOf g)
(riskFn : ι → ℝ) (empiricalRiskFn : Ω → ι → ℝ)
(varianceProxy : ι → ℝ)
(hgridCovers :
∀ ρ : ι → ℝ, IsPMF ρ →
∃ g : γ,
(klDiv ρ π + Real.log (1 / confidenceOf g)) / lambdaOf g +
lambdaOf g * posteriorMarginVarianceProxy ρ varianceProxy /
(2 * (1 - scaleOf g * lambdaOf g))
≤ posteriorPenalty ρ)
(hExpected :
∀ g : γ,
expectedPriorBernsteinExpMoment ν π (lambdaOf g) (scaleOf g)
riskFn empiricalRiskFn varianceProxy ≤ 1)
{delta : ℝ} (hgridConfidence : (∑ g : γ, confidenceOf g) ≤ delta) :
(∑ ω ∈
finitePACBayesBernsteinPenaltyBadSamples riskFn empiricalRiskFn posteriorPenalty,
ν ω) ≤ delta :=
(finitePACBayesBernsteinGridOptimized_badEventMass_le_sum_delta
(Ω := Ω) (ι := ι) (γ := γ) hν hπ lambdaOf scaleOf confidenceOf
posteriorPenalty hlambda hscale hconfidenceOf riskFn empiricalRiskFn
varianceProxy hgridCovers hExpected).trans hgridConfidence

end

end FormalSLT.PACBayesBernstein
Loading
Loading