Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Disable multithreading when seq_ends is passed as a tuple #113

Merged
merged 3 commits into from
Sep 29, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion examples/controlled.jl
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ function StatsAPI.fit!(
fb_storage::HMMs.ForwardBackwardStorage,
obs_seq::AbstractVector,
control_seq::AbstractVector;
seq_ends::AbstractVector{Int},
seq_ends,
) where {T}
(; γ, ξ) = fb_storage
N = length(hmm)
Expand Down
2 changes: 1 addition & 1 deletion examples/interfaces.jl
Original file line number Diff line number Diff line change
Expand Up @@ -186,7 +186,7 @@ function StatsAPI.fit!(
hmm::PriorHMM,
fb_storage::HiddenMarkovModels.ForwardBackwardStorage,
obs_seq::AbstractVector;
seq_ends::AbstractVector{Int},
seq_ends,
)
## initialize to defaults without observations
hmm.init .= 0
Expand Down
2 changes: 1 addition & 1 deletion examples/temporal.jl
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@ function StatsAPI.fit!(
fb_storage::HMMs.ForwardBackwardStorage,
obs_seq::AbstractVector,
control_seq::AbstractVector;
seq_ends::AbstractVector{Int},
seq_ends,
) where {T}
(; γ, ξ) = fb_storage
L, N = period(hmm), length(hmm)
Expand Down
1 change: 1 addition & 0 deletions libs/HMMTest/src/HMMTest.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ module HMMTest

using BenchmarkTools: @ballocated
using HiddenMarkovModels
using HiddenMarkovModels: AbstractVectorOrNTuple
import HiddenMarkovModels as HMMs
using HMMBase: HMMBase
using JET: @test_opt, @test_call
Expand Down
2 changes: 1 addition & 1 deletion libs/HMMTest/src/allocations.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ function test_allocations(
rng::AbstractRNG,
hmm::AbstractHMM,
control_seq::AbstractVector;
seq_ends::AbstractVector{Int},
seq_ends::AbstractVectorOrNTuple{Int},
hmm_guess::Union{Nothing,AbstractHMM}=nothing,
)
@testset "Allocations" begin
Expand Down
2 changes: 1 addition & 1 deletion libs/HMMTest/src/coherence.jl
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ function test_coherent_algorithms(
rng::AbstractRNG,
hmm::AbstractHMM,
control_seq::AbstractVector;
seq_ends::AbstractVector{Int},
seq_ends::AbstractVectorOrNTuple{Int},
hmm_guess::Union{Nothing,AbstractHMM}=nothing,
atol::Real=0.05,
init::Bool=true,
Expand Down
2 changes: 1 addition & 1 deletion libs/HMMTest/src/jet.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ function test_type_stability(
rng::AbstractRNG,
hmm::AbstractHMM,
control_seq::AbstractVector;
seq_ends::AbstractVector{Int},
seq_ends::AbstractVectorOrNTuple{Int},
hmm_guess::Union{Nothing,AbstractHMM}=nothing,
)
@testset "Type stability" begin
Expand Down
8 changes: 4 additions & 4 deletions src/inference/baum_welch.jl
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ function baum_welch!(
hmm::AbstractHMM,
obs_seq::AbstractVector,
control_seq::AbstractVector;
seq_ends::AbstractVector{Int},
seq_ends::AbstractVectorOrNTuple{Int},
atol::Real,
max_iterations::Integer,
loglikelihood_increasing::Bool,
Expand Down Expand Up @@ -55,7 +55,7 @@ function baum_welch(
hmm_guess::AbstractHMM,
obs_seq::AbstractVector,
control_seq::AbstractVector=Fill(nothing, length(obs_seq));
seq_ends::AbstractVector{Int}=Fill(length(obs_seq), 1),
seq_ends::AbstractVectorOrNTuple{Int}=(length(obs_seq),),
atol=1e-5,
max_iterations=100,
loglikelihood_increasing=true,
Expand All @@ -73,7 +73,7 @@ function baum_welch(
seq_ends,
atol,
max_iterations,
loglikelihood_increasing=false,
loglikelihood_increasing,
)
return hmm, logL_evolution
end
Expand All @@ -85,7 +85,7 @@ function StatsAPI.fit!(
fb_storage::ForwardBackwardStorage,
obs_seq::AbstractVector,
control_seq::AbstractVector;
seq_ends::AbstractVector{Int},
seq_ends::AbstractVectorOrNTuple{Int},
)
return fit!(hmm, fb_storage, obs_seq; seq_ends)
end
4 changes: 2 additions & 2 deletions src/inference/chainrules.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ function _params_and_loglikelihoods(
hmm::AbstractHMM,
obs_seq::Vector,
control_seq::AbstractVector=Fill(nothing, length(obs_seq));
seq_ends::AbstractVector{Int}=Fill(length(obs_seq), 1),
seq_ends::AbstractVectorOrNTuple{Int}=(length(obs_seq),),
)
init = initialization(hmm)
trans_by_time = mapreduce(_dcat, eachindex(control_seq)) do t
Expand All @@ -22,7 +22,7 @@ function ChainRulesCore.rrule(
hmm::AbstractHMM,
obs_seq::AbstractVector,
control_seq::AbstractVector=Fill(nothing, length(obs_seq));
seq_ends::AbstractVector{Int}=Fill(length(obs_seq), 1),
seq_ends::AbstractVectorOrNTuple{Int}=(length(obs_seq),),
)
_, pullback = rrule_via_ad(
rc, _params_and_loglikelihoods, hmm, obs_seq, control_seq; seq_ends
Expand Down
51 changes: 43 additions & 8 deletions src/inference/forward.jl
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,35 @@
c::Vector{R}
end

"""
$(TYPEDEF)

# Fields

Only the fields with a description are part of the public API.

$(TYPEDFIELDS)
"""
struct ForwardBackwardStorage{R,M<:AbstractMatrix{R}}
"posterior state marginals `γ[i,t] = ℙ(X[t]=i | Y[1:T])`"
γ::Matrix{R}
"posterior transition marginals `ξ[t][i,j] = ℙ(X[t]=i, X[t+1]=j | Y[1:T])`"
ξ::Vector{M}
"one loglikelihood per observation sequence"
logL::Vector{R}
B::Matrix{R}
α::Matrix{R}
c::Vector{R}
β::Matrix{R}
Bβ::Matrix{R}
end

Base.eltype(::ForwardStorage{R}) where {R} = R
Base.eltype(::ForwardBackwardStorage{R}) where {R} = R

const ForwardOrForwardBackwardStorage{R} = Union{
ForwardStorage{R},ForwardBackwardStorage{R}
}

"""
$(SIGNATURES)
Expand All @@ -25,7 +53,7 @@
hmm::AbstractHMM,
obs_seq::AbstractVector,
control_seq::AbstractVector;
seq_ends::AbstractVector{Int},
seq_ends::AbstractVectorOrNTuple{Int},
)
N, T, K = length(hmm), length(obs_seq), length(seq_ends)
R = eltype(hmm, obs_seq[1], control_seq[1])
Expand All @@ -40,7 +68,7 @@
$(SIGNATURES)
"""
function forward!(
storage,
storage::ForwardOrForwardBackwardStorage,
hmm::AbstractHMM,
obs_seq::AbstractVector,
control_seq::AbstractVector,
Expand Down Expand Up @@ -88,16 +116,23 @@
$(SIGNATURES)
"""
function forward!(
storage,
storage::ForwardOrForwardBackwardStorage,
hmm::AbstractHMM,
obs_seq::AbstractVector,
control_seq::AbstractVector;
seq_ends::AbstractVector{Int},
seq_ends::AbstractVectorOrNTuple{Int},
)
(; logL) = storage
@threads for k in eachindex(seq_ends)
t1, t2 = seq_limits(seq_ends, k)
logL[k] = forward!(storage, hmm, obs_seq, control_seq, t1, t2;)
if seq_ends isa NTuple
for k in eachindex(seq_ends)
t1, t2 = seq_limits(seq_ends, k)
logL[k] = forward!(storage, hmm, obs_seq, control_seq, t1, t2;)
end
else
@threads for k in eachindex(seq_ends)
t1, t2 = seq_limits(seq_ends, k)
logL[k] = forward!(storage, hmm, obs_seq, control_seq, t1, t2;)
end

Check warning on line 135 in src/inference/forward.jl

View check run for this annotation

Codecov / codecov/patch

src/inference/forward.jl#L135

Added line #L135 was not covered by tests
end
return nothing
end
Expand All @@ -113,7 +148,7 @@
hmm::AbstractHMM,
obs_seq::AbstractVector,
control_seq::AbstractVector=Fill(nothing, length(obs_seq));
seq_ends::AbstractVector{Int}=Fill(length(obs_seq), 1),
seq_ends::AbstractVectorOrNTuple{Int}=(length(obs_seq),),
)
storage = initialize_forward(hmm, obs_seq, control_seq; seq_ends)
forward!(storage, hmm, obs_seq, control_seq; seq_ends)
Expand Down
54 changes: 19 additions & 35 deletions src/inference/forward_backward.jl
Original file line number Diff line number Diff line change
@@ -1,36 +1,11 @@
"""
$(TYPEDEF)

# Fields

Only the fields with a description are part of the public API.

$(TYPEDFIELDS)
"""
struct ForwardBackwardStorage{R,M<:AbstractMatrix{R}}
"posterior state marginals `γ[i,t] = ℙ(X[t]=i | Y[1:T])`"
γ::Matrix{R}
"posterior transition marginals `ξ[t][i,j] = ℙ(X[t]=i, X[t+1]=j | Y[1:T])`"
ξ::Vector{M}
"one loglikelihood per observation sequence"
logL::Vector{R}
B::Matrix{R}
α::Matrix{R}
c::Vector{R}
β::Matrix{R}
Bβ::Matrix{R}
end

Base.eltype(::ForwardBackwardStorage{R}) where {R} = R

"""
$(SIGNATURES)
"""
function initialize_forward_backward(
hmm::AbstractHMM,
obs_seq::AbstractVector,
control_seq::AbstractVector;
seq_ends::AbstractVector{Int},
seq_ends::AbstractVectorOrNTuple{Int},
transition_marginals=true,
)
N, T, K = length(hmm), length(obs_seq), length(seq_ends)
Expand Down Expand Up @@ -100,20 +75,29 @@
$(SIGNATURES)
"""
function forward_backward!(
storage::ForwardBackwardStorage{R},
storage::ForwardBackwardStorage,
hmm::AbstractHMM,
obs_seq::AbstractVector,
control_seq::AbstractVector;
seq_ends::AbstractVector{Int},
seq_ends::AbstractVectorOrNTuple{Int},
transition_marginals::Bool=true,
) where {R}
)
(; logL) = storage
@threads for k in eachindex(seq_ends)
t1, t2 = seq_limits(seq_ends, k)
logL[k] = forward_backward!(
storage, hmm, obs_seq, control_seq, t1, t2; transition_marginals
)
if seq_ends isa NTuple
for k in eachindex(seq_ends)
t1, t2 = seq_limits(seq_ends, k)
logL[k] = forward_backward!(
storage, hmm, obs_seq, control_seq, t1, t2; transition_marginals
)
end
else
@threads for k in eachindex(seq_ends)
t1, t2 = seq_limits(seq_ends, k)
logL[k] = forward_backward!(
storage, hmm, obs_seq, control_seq, t1, t2; transition_marginals
)
end
end

Check warning on line 100 in src/inference/forward_backward.jl

View check run for this annotation

Codecov / codecov/patch

src/inference/forward_backward.jl#L99-L100

Added lines #L99 - L100 were not covered by tests
return nothing
end

Expand All @@ -128,7 +112,7 @@
hmm::AbstractHMM,
obs_seq::AbstractVector,
control_seq::AbstractVector=Fill(nothing, length(obs_seq));
seq_ends::AbstractVector{Int}=Fill(length(obs_seq), 1),
seq_ends::AbstractVectorOrNTuple{Int}=(length(obs_seq),),
)
transition_marginals = false
storage = initialize_forward_backward(
Expand Down
4 changes: 2 additions & 2 deletions src/inference/logdensity.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ function DensityInterface.logdensityof(
hmm::AbstractHMM,
obs_seq::AbstractVector,
control_seq::AbstractVector=Fill(nothing, length(obs_seq));
seq_ends::AbstractVector{Int}=Fill(length(obs_seq), 1),
seq_ends::AbstractVectorOrNTuple{Int}=(length(obs_seq),),
)
_, logL = forward(hmm, obs_seq, control_seq; seq_ends)
return sum(logL)
Expand All @@ -23,7 +23,7 @@ function joint_logdensityof(
obs_seq::AbstractVector,
state_seq::AbstractVector,
control_seq::AbstractVector=Fill(nothing, length(obs_seq));
seq_ends::AbstractVector{Int}=Fill(length(obs_seq), 1),
seq_ends::AbstractVectorOrNTuple{Int}=(length(obs_seq),),
)
R = eltype(hmm, obs_seq[1], control_seq[1])
logL = zero(R)
Expand Down
19 changes: 13 additions & 6 deletions src/inference/viterbi.jl
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ function initialize_viterbi(
hmm::AbstractHMM,
obs_seq::AbstractVector,
control_seq::AbstractVector;
seq_ends::AbstractVector{Int},
seq_ends::AbstractVectorOrNTuple{Int},
)
N, T, K = length(hmm), length(obs_seq), length(seq_ends)
R = eltype(hmm, obs_seq[1], control_seq[1])
Expand Down Expand Up @@ -85,12 +85,19 @@ function viterbi!(
hmm::AbstractHMM,
obs_seq::AbstractVector,
control_seq::AbstractVector;
seq_ends::AbstractVector{Int},
seq_ends::AbstractVectorOrNTuple{Int},
) where {R}
(; logL) = storage
@threads for k in eachindex(seq_ends)
t1, t2 = seq_limits(seq_ends, k)
logL[k] = viterbi!(storage, hmm, obs_seq, control_seq, t1, t2;)
if seq_ends isa NTuple
for k in eachindex(seq_ends)
t1, t2 = seq_limits(seq_ends, k)
logL[k] = viterbi!(storage, hmm, obs_seq, control_seq, t1, t2;)
end
else
@threads for k in eachindex(seq_ends)
t1, t2 = seq_limits(seq_ends, k)
logL[k] = viterbi!(storage, hmm, obs_seq, control_seq, t1, t2;)
end
end
return nothing
end
Expand All @@ -106,7 +113,7 @@ function viterbi(
hmm::AbstractHMM,
obs_seq::AbstractVector,
control_seq::AbstractVector=Fill(nothing, length(obs_seq));
seq_ends::AbstractVector{Int}=Fill(length(obs_seq), 1),
seq_ends::AbstractVectorOrNTuple{Int}=(length(obs_seq),),
)
storage = initialize_viterbi(hmm, obs_seq, control_seq; seq_ends)
viterbi!(storage, hmm, obs_seq, control_seq; seq_ends)
Expand Down
2 changes: 1 addition & 1 deletion src/types/hmm.jl
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ function StatsAPI.fit!(
hmm::HMM,
fb_storage::ForwardBackwardStorage,
obs_seq::AbstractVector;
seq_ends::AbstractVector{Int},
seq_ends::AbstractVectorOrNTuple{Int},
)
(; γ, ξ) = fb_storage
# Fit states
Expand Down
2 changes: 1 addition & 1 deletion src/utils/limits.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ $(SIGNATURES)

Return a tuple `(t1, t2)` giving the begin and end indices of subsequence `k` within a set of sequences ending at `seq_ends`.
"""
function seq_limits(seq_ends::AbstractVector{Int}, k::Integer)
function seq_limits(seq_ends::AbstractVectorOrNTuple{Int}, k::Integer)
if k == 1
return 1, seq_ends[k]
else
Expand Down
2 changes: 2 additions & 0 deletions src/utils/linalg.jl
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
const AbstractVectorOrNTuple{T} = Union{AbstractVector{T},NTuple{N,T}} where {N}

sum_to_one!(x) = ldiv!(sum(x), x)

mynonzeros(x::AbstractArray) = x
Expand Down
Loading