Skip to content

Commit

Permalink
Merge pull request #221 from biaslab/poisson_fix
Browse files Browse the repository at this point in the history
Fix poisson node
  • Loading branch information
ThijsvdLaar authored May 3, 2023
2 parents 17f2755 + 71c0894 commit a4e9741
Show file tree
Hide file tree
Showing 3 changed files with 33 additions and 8 deletions.
34 changes: 28 additions & 6 deletions src/factor_nodes/poisson.jl
Original file line number Diff line number Diff line change
Expand Up @@ -66,20 +66,42 @@ logPdf(V::Type{Univariate}, F::Type{Poisson}, x::Number; η::Vector) = -logfacto
# ∑ [λ^k*log(k!)]/k! from k=0 to inf
# Approximates the above sum for calculation of averageEnergy and differentialEntropy
# @ref https://arxiv.org/pdf/1708.06394.pdf
function apprSum(l, j=100)
sum([(l)^(k)*logfactorial(k)/exp(logfactorial(k)) for k in collect(0:j)])
function approximatePowerSum(l, j=150)
(l == 0.0) && return 0.0
(l > 110.0) && error("Cannot approximate power sum for Poisson distribution with l>110")

s = zero(BigFloat)
lk = one(BigFloat)
for k = 1:j
lk *= l
s += lk*loggamma(k + 1)/gamma(k + 1)
end

return convert(Float64, s)
end

# Entropy functional
# @ref https://en.wikipedia.org/wiki/Poisson_distribution
function differentialEntropy(dist::Distribution{Univariate, Poisson})
l = clamp(dist.params[:l], tiny, huge)
l*(1-log(l)) + exp(-l)*apprSum(l)
l = dist.params[:l]
(l == 0.0) && return 0.0

if l <= 50.0
return l*(1-log(l)) + exp(-l)*approximatePowerSum(l)
else
return 0.5*log(2*pi**l) - 1/(12*l) - 1/(24*l^2) - 19/(360*l^3)
end
end

# Average energy functional
# Average energy functionals
function averageEnergy(::Type{Poisson}, marg_out::Distribution{Univariate}, marg_l::Distribution{Univariate})
unsafeMean(marg_l) -
unsafeMean(marg_out)*unsafeLogMean(marg_l) +
exp(-unsafeMean(marg_out))*apprSum(unsafeMean(marg_out))
exp(-unsafeMean(marg_out))*approximatePowerSum(unsafeMean(marg_out))
end

function averageEnergy(::Type{Poisson}, marg_out::Distribution{Univariate, PointMass}, marg_l::Distribution{Univariate})
unsafeMean(marg_l) -
unsafeMean(marg_out)*unsafeLogMean(marg_l) +
sum(log.(1:unsafeMean(marg_out)))
end
5 changes: 4 additions & 1 deletion test/factor_nodes/test_poisson.jl
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,10 @@ end
@testset "averageEnergy and differentialEntropy" begin
@test isapprox(differentialEntropy(Distribution(Poisson, l=1.0)), averageEnergy(Poisson, Distribution(Poisson, l=1.0), Distribution(Univariate, PointMass, m=1.0)))
@test isapprox(differentialEntropy(Distribution(Poisson, l=10.0)), averageEnergy(Poisson, Distribution(Poisson, l=10.0), Distribution(Univariate, PointMass, m=10.0)))
@test isapprox(differentialEntropy(Distribution(Poisson, l=100.0)), averageEnergy(Poisson, Distribution(Poisson, l=100.0), Distribution(Univariate, PointMass, m=100.0)))
@test isapprox(differentialEntropy(Distribution(Poisson, l=100.0)), averageEnergy(Poisson, Distribution(Poisson, l=100.0), Distribution(Univariate, PointMass, m=100.0)), atol=0.1)

@test averageEnergy(Poisson, Distribution(PointMass, m=1.0), Distribution(Univariate, PointMass, m=1.0)) == 1.0
@test averageEnergy(Poisson, Distribution(PointMass, m=2.0), Distribution(Univariate, PointMass, m=1.0)) == 1.0 + log(2)
end

end # module
2 changes: 1 addition & 1 deletion test/test_helpers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ using LinearAlgebra: Diagonal, isposdef, I, Hermitian
@test bar(1, Tuple{Float64, Float32}((1.0, 1.0f0))) === bar(Tuple{Float64, Float32}((1.0, 1.0f0)), 1)
@test bar(1, Tuple{Float32, Float64}((1.0f0, 1.0))) === bar(Tuple{Float32, Float64}((1.0f0, 1.0)), 1)

@symmetrical function baz(a::Int, b::Float64, c::String) where A where B where C
@symmetrical function baz(a::Int, b::Float64, c::String)
return 1
end

Expand Down

0 comments on commit a4e9741

Please sign in to comment.