Skip to content

Commit

Permalink
Functional API for applying transition matrix
Browse files Browse the repository at this point in the history
  • Loading branch information
gdalle committed Nov 15, 2024
1 parent a370596 commit 0947ce1
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 3 deletions.
7 changes: 4 additions & 3 deletions src/inference/forward_backward.jl
Original file line number Diff line number Diff line change
Expand Up @@ -47,10 +47,11 @@ function _forward_backward!(
# Backward
β[:, t2] .= c[t2]
for t in (t2 - 1):-1:t1
trans = transition_matrix(hmm, control_seq[t])
Bβ[:, t + 1] .= view(B, :, t + 1) .* view(β, :, t + 1)
mul!(view(β, :, t), trans, view(Bβ, :, t + 1))
lmul!(c[t], view(β, :, t))
βₜ = view(β, :, t)
Bβₜ₊₁ = view(Bβ, :, t + 1)
predict_previous_state!(βₜ, hmm, Bβₜ₊₁, control_seq[t])
lmul!(c[t], βₜ)
end
Bβ[:, t1] .= view(B, :, t1) .* view(β, :, t1)

Expand Down
11 changes: 11 additions & 0 deletions src/inference/predict.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,3 +8,14 @@ function predict_next_state!(
mul!(next_state_marginals, transpose(trans), current_state_marginals)
return next_state_marginals
end

function predict_previous_state!(
previous_state_marginals::AbstractVector{<:Real},
hmm::AbstractHMM,
current_state_marginals::AbstractVector{<:Real},
control=nothing,
)
trans = transition_matrix(hmm, control)
mul!(previous_state_marginals, trans, current_state_marginals)
return previous_state_marginals
end

0 comments on commit 0947ce1

Please sign in to comment.