diff --git a/src/distributions/truncate.jl b/src/distributions/truncate.jl index ccb0d06..c9ad901 100644 --- a/src/distributions/truncate.jl +++ b/src/distributions/truncate.jl @@ -19,6 +19,7 @@ realtype(::Type{TruncatedDistribution{D, T1, T2}}) where {D, T1, T2} = promote_t realtype(::TruncatedDistribution{D, T1, T2}) where {D, T1, T2} = promote_type(realtype(D), T1, T2) normalization_constant(d::TruncatedDistribution) = cdf(get_dist(d), get_upper(d)) - cdf(get_dist(d), get_lower(d)) + pmf(get_dist(d), get_lower(d)) +lognormalization_constant(d::TruncatedDistribution) = logaddexp(logsubexp(logcdf(get_dist(d), get_upper(d)), logcdf(get_dist(d), get_lower(d))), logpmf(get_dist(d), get_lower(d))) function pdf(d::TruncatedDistribution, x::T) where { T <: Real } Tx = promote_type(T, realtype(d)) @@ -36,6 +37,15 @@ function cdf(d::TruncatedDistribution{D}, x::T) where { T, D } end return (cdf(get_dist(d), x) - cdf(get_dist(d), get_lower(d))) / normalization_constant(d) end +function logcdf(d::TruncatedDistribution, x::T) where { T <: Real } + Tx = promote_type(T, realtype(d)) + if x <= get_lower(d) + return -Inf + elseif x > get_upper(d) + return zero(Tx) + end + return logsubexp(logcdf(get_dist(d), x), logcdf(get_dist(d), get_lower(d))) - lognormalization_constant(d) +end function invcdf(d::TruncatedDistribution, p) return invcdf(get_dist(d), cdf(get_dist(d), get_lower(d)) + p * normalization_constant(d)) end diff --git a/test/distributions/truncate_tests.jl b/test/distributions/truncate_tests.jl index 24eb435..906eaab 100644 --- a/test/distributions/truncate_tests.jl +++ b/test/distributions/truncate_tests.jl @@ -1,6 +1,6 @@ @testitem "Truncated distributions" begin - using UnboundedBNN: TruncatedDistribution, pdf, cdf, Normal, normalization_constant, invcdf, get_dist, get_lower, get_upper, truncate, expand_truncation_to_ints, realtype, support + using UnboundedBNN: TruncatedDistribution, pdf, logcdf, cdf, Normal, lognormalization_constant, normalization_constant, invcdf, get_dist, get_lower, get_upper, truncate, expand_truncation_to_ints, realtype, support @testset "get" begin @test get_dist(TruncatedDistribution(Normal(0, 1), 0, Inf)) == Normal(0, 1) @@ -46,6 +46,14 @@ @test normalization_constant(TruncatedDistribution(Poisson(1), 0, 3)) ≈ sum(pmf.(Ref(Poisson(1)), 0:3)) @test normalization_constant(TruncatedDistribution(Poisson(1), 1, 3)) ≈ sum(pmf.(Ref(Poisson(1)), 1:3)) end + + @testset "lognormalization_constant" begin + @test lognormalization_constant(TruncatedDistribution(Normal(0, 1), -Inf, Inf)) ≈ 0 + @test lognormalization_constant(TruncatedDistribution(Normal(0, 1), -Inf, 0)) ≈ log(0.5) + @test lognormalization_constant(TruncatedDistribution(Normal(2,3), 2, Inf)) ≈ log(0.5) + @test lognormalization_constant(TruncatedDistribution(Poisson(1), 0, 3)) ≈ log(sum(pmf.(Ref(Poisson(1)), 0:3))) + @test lognormalization_constant(TruncatedDistribution(Poisson(1), 1, 3)) ≈ log(sum(pmf.(Ref(Poisson(1)), 1:3))) + end @testset "pdf" begin @test pdf(TruncatedDistribution(Normal(0, 1), 0, Inf), -1) ≈ 0 @@ -61,12 +69,27 @@ @test pdf(TruncatedDistribution(Normal(2,3), 2, Inf), 2) ≈ 2 / sqrt(2π) / 3 @test pdf(TruncatedDistribution(Normal(2,3), 2, Inf), 3) ≈ 2 / sqrt(2π) / 3 * exp(-0.5/9) end + + @testset "pmf" begin + @test pmf(TruncatedDistribution(Poisson(1), 0, 3), -1) ≈ 0 + end @testset "cdf" begin @test cdf(TruncatedDistribution(Normal(0, 1), 0, Inf), -1) ≈ 0 @test cdf(TruncatedDistribution(Normal(0, 1), 0, Inf), 0) ≈ 0 @test cdf(TruncatedDistribution(Normal(0, 1), 0, Inf), 2) ≈ 1 - 2 * cdf(Normal(0, 1), -2) @test cdf(TruncatedDistribution(Normal(0, 1), 0, Inf), Inf) ≈ 1 + @test cdf(TruncatedDistribution(Normal(0, 1), -Inf, 0), 0) ≈ 1 + @test cdf(TruncatedDistribution(Normal(0, 1), -Inf, 0), 1) ≈ 1 + end + + @testset "logcdf" begin + @test logcdf(TruncatedDistribution(Normal(0, 1), 0, Inf), -1) ≈ -Inf + @test logcdf(TruncatedDistribution(Normal(0, 1), 0, Inf), 0) ≈ -Inf + @test logcdf(TruncatedDistribution(Normal(0, 1), 0, Inf), 2) ≈ log(1 - 2 * cdf(Normal(0, 1), -2)) + @test logcdf(TruncatedDistribution(Normal(0, 1), 0, Inf), Inf) ≈ 0 + @test logcdf(TruncatedDistribution(Normal(0, 1), -Inf, 0), 0) ≈ 0 + @test logcdf(TruncatedDistribution(Normal(0, 1), -Inf, 0), 1) ≈ 0 end @testset "inv_cdf" begin @@ -76,7 +99,7 @@ @test invcdf(TruncatedDistribution(Normal(3, 5), -1, 7), 0.5) ≈ 3 @test invcdf(TruncatedDistribution(Normal(3, 5), -1, 7), 1.0) ≈ 7 end - + @testset "truncate" begin @test truncate(Normal(0, 1), 0, Inf) == TruncatedDistribution(Normal(0, 1), 0, Inf) @test truncate(TruncatedDistribution(Normal(0, 1), 0, Inf), 0, Inf) == TruncatedDistribution(Normal(0, 1), 0, Inf) @@ -86,6 +109,10 @@ @test truncate(TruncatedDistribution(Normal(2, 3), 3, 8), 2, Inf) == TruncatedDistribution(Normal(2, 3), 3, 8.0) end + @testset "truncate_to_quantiles" begin + @test truncate_to_quantiles(Normal(0,1), 0, 1) == TruncatedDistribution(Normal(0, 1), -Inf, Inf) + end + @testset "expand_truncation_to_ints" begin @test expand_truncation_to_ints(TruncatedDistribution(Normal(0, 1), 0, 5.4)) == TruncatedDistribution(Normal(0, 1), 0, 6) @test expand_truncation_to_ints(TruncatedDistribution(Normal(0, 1), -5.4, 0)) == TruncatedDistribution(Normal(0, 1), -6, 0)