Skip to content

Commit

Permalink
move KL div normals
Browse files Browse the repository at this point in the history
  • Loading branch information
bartvanerp committed Aug 23, 2024
1 parent 406dda8 commit c123320
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 4 deletions.
6 changes: 5 additions & 1 deletion src/distributions/normal.jl
Original file line number Diff line number Diff line change
Expand Up @@ -30,4 +30,8 @@ support(::SafeNormal) = -Inf, Inf
pdf(d::UnionNormal, x::Real) = normpdf(get_μ(d), get_σ(d), x)
logpdf(d::UnionNormal, x::Real) = normlogpdf(get_μ(d), get_σ(d), x)
cdf(d::UnionNormal, x::Real) = normcdf(get_μ(d), get_σ(d), x)
invcdf(d::UnionNormal, p::Real) = norminvcdf(get_μ(d), get_σ(d), p)
invcdf(d::UnionNormal, p::Real) = norminvcdf(get_μ(d), get_σ(d), p)

KL_loss(d1::UnionNormal, d2::UnionNormal) = KL_normals(get_μ(d1), get_σ(d1)^2, get_μ(d2), get_σ(d2)^2)
KL_normals(m, v) = KL_normals(m, v, 0, 1)
KL_normals(pm, pv, qm, qv) = (log(qv/pv) + (pv + abs2(pm - qm) )/qv - 1)/2
3 changes: 0 additions & 3 deletions src/layers/linear.jl
Original file line number Diff line number Diff line change
Expand Up @@ -28,9 +28,6 @@ function (l::Linear)(x_mean, x_var)
return y_mean, y_var
end

KL_normals(m, v) = KL_normals(m, v, 0, 1)
KL_normals(pm, pv, qm, qv) = (log(qv/pv) + (pv + abs2(pm - qm) )/qv - 1)/2

function KL_loss(l::Linear)
# ASSUMES STANDARD NORMAL PRIOR
W_var = softplus.(l.W_wvar)
Expand Down

0 comments on commit c123320

Please sign in to comment.