Skip to content

Commit

Permalink
improve robustness truncated
Browse files Browse the repository at this point in the history
  • Loading branch information
bartvanerp committed Aug 27, 2024
1 parent 9b500ef commit 678c81d
Show file tree
Hide file tree
Showing 2 changed files with 39 additions and 2 deletions.
10 changes: 10 additions & 0 deletions src/distributions/truncate.jl
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ realtype(::Type{TruncatedDistribution{D, T1, T2}}) where {D, T1, T2} = promote_t
realtype(::TruncatedDistribution{D, T1, T2}) where {D, T1, T2} = promote_type(realtype(D), T1, T2)

normalization_constant(d::TruncatedDistribution) = cdf(get_dist(d), get_upper(d)) - cdf(get_dist(d), get_lower(d)) + pmf(get_dist(d), get_lower(d))
lognormalization_constant(d::TruncatedDistribution) = logaddexp(logsubexp(logcdf(get_dist(d), get_upper(d)), logcdf(get_dist(d), get_lower(d))), logpmf(get_dist(d), get_lower(d)))

function pdf(d::TruncatedDistribution, x::T) where { T <: Real }
Tx = promote_type(T, realtype(d))
Expand All @@ -36,6 +37,15 @@ function cdf(d::TruncatedDistribution{D}, x::T) where { T, D }
end
return (cdf(get_dist(d), x) - cdf(get_dist(d), get_lower(d))) / normalization_constant(d)
end
function logcdf(d::TruncatedDistribution, x::T) where { T <: Real }
Tx = promote_type(T, realtype(d))
if x <= get_lower(d)
return -Inf
elseif x > get_upper(d)
return zero(Tx)
end
return logsubexp(logcdf(get_dist(d), x), logcdf(get_dist(d), get_lower(d))) - lognormalization_constant(d)
end
function invcdf(d::TruncatedDistribution, p)
return invcdf(get_dist(d), cdf(get_dist(d), get_lower(d)) + p * normalization_constant(d))
end
Expand Down
31 changes: 29 additions & 2 deletions test/distributions/truncate_tests.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
@testitem "Truncated distributions" begin

using UnboundedBNN: TruncatedDistribution, pdf, cdf, Normal, normalization_constant, invcdf, get_dist, get_lower, get_upper, truncate, expand_truncation_to_ints, realtype, support
using UnboundedBNN: TruncatedDistribution, pdf, logcdf, cdf, Normal, lognormalization_constant, normalization_constant, invcdf, get_dist, get_lower, get_upper, truncate, expand_truncation_to_ints, realtype, support

@testset "get" begin
@test get_dist(TruncatedDistribution(Normal(0, 1), 0, Inf)) == Normal(0, 1)
Expand Down Expand Up @@ -46,6 +46,14 @@
@test normalization_constant(TruncatedDistribution(Poisson(1), 0, 3)) sum(pmf.(Ref(Poisson(1)), 0:3))
@test normalization_constant(TruncatedDistribution(Poisson(1), 1, 3)) sum(pmf.(Ref(Poisson(1)), 1:3))
end

@testset "lognormalization_constant" begin
@test lognormalization_constant(TruncatedDistribution(Normal(0, 1), -Inf, Inf)) 0
@test lognormalization_constant(TruncatedDistribution(Normal(0, 1), -Inf, 0)) log(0.5)
@test lognormalization_constant(TruncatedDistribution(Normal(2,3), 2, Inf)) log(0.5)
@test lognormalization_constant(TruncatedDistribution(Poisson(1), 0, 3)) log(sum(pmf.(Ref(Poisson(1)), 0:3)))
@test lognormalization_constant(TruncatedDistribution(Poisson(1), 1, 3)) log(sum(pmf.(Ref(Poisson(1)), 1:3)))
end

@testset "pdf" begin
@test pdf(TruncatedDistribution(Normal(0, 1), 0, Inf), -1) 0
Expand All @@ -61,12 +69,27 @@
@test pdf(TruncatedDistribution(Normal(2,3), 2, Inf), 2) 2 / sqrt(2π) / 3
@test pdf(TruncatedDistribution(Normal(2,3), 2, Inf), 3) 2 / sqrt(2π) / 3 * exp(-0.5/9)
end

@testset "pmf" begin
@test pmf(TruncatedDistribution(Poisson(1), 0, 3), -1) 0
end

@testset "cdf" begin
@test cdf(TruncatedDistribution(Normal(0, 1), 0, Inf), -1) 0
@test cdf(TruncatedDistribution(Normal(0, 1), 0, Inf), 0) 0
@test cdf(TruncatedDistribution(Normal(0, 1), 0, Inf), 2) 1 - 2 * cdf(Normal(0, 1), -2)
@test cdf(TruncatedDistribution(Normal(0, 1), 0, Inf), Inf) 1
@test cdf(TruncatedDistribution(Normal(0, 1), -Inf, 0), 0) 1
@test cdf(TruncatedDistribution(Normal(0, 1), -Inf, 0), 1) 1
end

@testset "logcdf" begin
@test logcdf(TruncatedDistribution(Normal(0, 1), 0, Inf), -1) -Inf
@test logcdf(TruncatedDistribution(Normal(0, 1), 0, Inf), 0) -Inf
@test logcdf(TruncatedDistribution(Normal(0, 1), 0, Inf), 2) log(1 - 2 * cdf(Normal(0, 1), -2))
@test logcdf(TruncatedDistribution(Normal(0, 1), 0, Inf), Inf) 0
@test logcdf(TruncatedDistribution(Normal(0, 1), -Inf, 0), 0) 0
@test logcdf(TruncatedDistribution(Normal(0, 1), -Inf, 0), 1) 0
end

@testset "inv_cdf" begin
Expand All @@ -76,7 +99,7 @@
@test invcdf(TruncatedDistribution(Normal(3, 5), -1, 7), 0.5) 3
@test invcdf(TruncatedDistribution(Normal(3, 5), -1, 7), 1.0) 7
end

@testset "truncate" begin
@test truncate(Normal(0, 1), 0, Inf) == TruncatedDistribution(Normal(0, 1), 0, Inf)
@test truncate(TruncatedDistribution(Normal(0, 1), 0, Inf), 0, Inf) == TruncatedDistribution(Normal(0, 1), 0, Inf)
Expand All @@ -86,6 +109,10 @@
@test truncate(TruncatedDistribution(Normal(2, 3), 3, 8), 2, Inf) == TruncatedDistribution(Normal(2, 3), 3, 8.0)
end

@testset "truncate_to_quantiles" begin
@test truncate_to_quantiles(Normal(0,1), 0, 1) == TruncatedDistribution(Normal(0, 1), -Inf, Inf)
end

@testset "expand_truncation_to_ints" begin
@test expand_truncation_to_ints(TruncatedDistribution(Normal(0, 1), 0, 5.4)) == TruncatedDistribution(Normal(0, 1), 0, 6)
@test expand_truncation_to_ints(TruncatedDistribution(Normal(0, 1), -5.4, 0)) == TruncatedDistribution(Normal(0, 1), -6, 0)
Expand Down

0 comments on commit 678c81d

Please sign in to comment.