Skip to content

Commit

Permalink
baum-welch algorithm for hmm fixed
Browse files Browse the repository at this point in the history
  • Loading branch information
maximtrp committed Sep 12, 2020
1 parent 2bb3f87 commit b947597
Show file tree
Hide file tree
Showing 5 changed files with 154 additions and 68 deletions.
44 changes: 43 additions & 1 deletion README.rst
Original file line number Diff line number Diff line change
Expand Up @@ -196,13 +196,16 @@ Directed graph of the hidden Markov model:

.. image:: images/hmm.png

Viterbi algorithm
.................

Running Viterbi algorithm on new observations.

.. code:: python
>>> new_obs = "GGCATTGGGCTATAAGAGGAGCTTG"
>>> vs, vsi = a.viterbi(new_obs)
>>> # states sequences obtained with both algorithms
>>> # states sequence
>>> print("VI", "".join(vs))
>>> # observations
>>> print("NO", new_obs)
Expand All @@ -211,3 +214,42 @@ Running Viterbi algorithm on new observations.

VI 0000000001111100000000000
NO GGCATTGGGCTATAAGAGGAGCTTG

Baum-Welch algorithm
....................

Using Baum-Welch algorithm to infer the parameters of a Hidden Markov model:

.. code:: python
>>> obs_seq = 'AGACTGCATATATAAGGGGCAGGCTG'
>>> a = hmm.HiddenMarkovModel().from_baum_welch(obs_seq, states=['0', '1'])
>>> # training log: KL divergence values for all iterations
>>> a.log
::

{
'tp': [0.008646969455670256, 0.0012397829805491124, 0.0003950986109761759],
'ep': [0.09078874423746826, 0.0022734816599056084, 0.0010118204023946836],
'pi': [0.009030829793043593, 0.016658391248503462, 0.0038894983546756065]
}

Inferred transition (`tp`), emission (`ep`) probability matrices and
initial state distribution (`pi`) can be accessed as shown:

.. code:: python
>>> a.ep, a.tp, a.pi
.. code:: python
>>> new_obs = "GGCATTGGGCTATAAGAGGAGCTTG"
>>> vs, vsi = m.viterbi(new_obs)
>>> print("VI", "".join(vs))
>>> print("NO", new_obs)
::

VI 0011100001111100000001100
NO GGCATTGGGCTATAAGAGGAGCTTG
2 changes: 1 addition & 1 deletion mchmm/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
__version__ = "0.3.3"
__version__ = "0.4.0"

from ._hmm import * # noqa
from ._mc import * # noqa
167 changes: 107 additions & 60 deletions mchmm/_hmm.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,9 @@

class HiddenMarkovModel:

def __init__(self, observations=None, states=None, tp=None, ep=None, pi=None):
def __init__(
self, observations=None, states=None, tp=None, ep=None, pi=None
):
'''Hidden Markov model.
Parameters
Expand Down Expand Up @@ -40,7 +42,7 @@ def __init__(self, observations=None, states=None, tp=None, ep=None, pi=None):
self.pi = np.array(pi)

def _transition_matrix(self, seq=None, states=None):
'''Calculate a transition probability matrix which stores the transition
'''Calculate a transition probability matrix which stores transition
probability of transiting from state i to state j.
Parameters
Expand Down Expand Up @@ -76,7 +78,9 @@ def _transition_matrix(self, seq=None, states=None):
matrix /= matrix.sum(axis=1)[:, None]
return matrix

def _emission_matrix(self, obs_seq=None, states_seq=None, obs=None, states=None):
def _emission_matrix(
self, obs_seq=None, states_seq=None, obs=None, states=None
):
'''Calculate an emission probability matrix.
Parameters
Expand Down Expand Up @@ -140,6 +144,11 @@ def from_seq(self, obs_seq, states_seq, pi=None, end=None, seed=None):
pi_seed : int, optional
Random state used to draw random variates. Passed to
`scipy.stats.uniform` method.
Returns
-------
model : object
Hidden Markov model learned from the given data.
'''

self.obs_seq = np.array(list(obs_seq))
Expand All @@ -150,16 +159,22 @@ def from_seq(self, obs_seq, states_seq, pi=None, end=None, seed=None):
self.ep = self._emission_matrix(self.obs_seq, self.states_seq)

if pi is None:
self.pi = ss.uniform().rvs(size=self.states.size, random_state=seed)
self.pi = ss.uniform().rvs(
size=self.states.size, random_state=seed
)
self.pi /= self.pi.sum()

if end is None:
self.end = ss.uniform().rvs(size=self.states.size, random_state=seed)
self.end = ss.uniform().rvs(
size=self.states.size, random_state=seed
)
self.end /= self.end.sum()

return self

def viterbi(self, obs_seq, obs=None, states=None, tp=None, ep=None, pi=None):
def viterbi(
self, obs_seq, obs=None, states=None, tp=None, ep=None, pi=None
):
'''Viterbi algorithm.
Parameters
Expand Down Expand Up @@ -240,24 +255,27 @@ def s(i):

return x, z

def baum_welch(self, obs_seq, iters=100, obs=None, states=None, tp=None, ep=None,
pi=None, end=None):
def from_baum_welch(
self, obs_seq, states, thres=0.001, obs=None,
tp=None, ep=None, pi=None, end=None
):
'''Baum-Welch algorithm.
Parameters
----------
obs_seq : array_like
Sequence of observations.
iters : int, optional
Number of iterations. Default is 100.
states : array_like, optional
List of states (of size K).
thres : float
Convergence threshold. Kullback-Leibler divergence value below
which model training is stopped.
obs : array_like, optional
Observations space (of size N).
states : array_like, optional
List of states (of size K).
tp : array_like or numpy ndarray, optional
Transition matrix (of size K × K) which stores transition
probability of transiting from state i (row) to state j (col).
Expand All @@ -270,80 +288,109 @@ def baum_welch(self, obs_seq, iters=100, obs=None, states=None, tp=None, ep=None
pi : array_like or numpy ndarray, optional
Initial probabilities array (of size K).
end : array_like or numpy ndarray, optional
Terminal probabilities array (of size K).
Returns
-------
x : numpy ndarray
Sequence of states.
z : numpy ndarray
Sequence of state indices.
model : object
Hidden Markov model trained using Baum-Welch algorithm.
'''

if states is None:
states = self.states
obs_seq = np.array(list(obs_seq))

if obs is None:
obs = np.unique(obs_seq)

K = len(states)
N = len(obs)

if tp is None:
tp = self.tp
tp = np.random.random((K, K))
tp /= tp.sum(axis=1)[:, None]

if ep is None:
ep = self.ep
ep = np.random.random((K, N))
ep /= ep.sum(axis=1)[:, None]

if pi:
pi = np.array(pi)
else:
pi = self.pi

if end:
end = np.array(end)
else:
end = self.end

obs_seq = np.array(list(obs_seq))

if obs is None:
obs = np.unique(obs_seq)
pi = np.random.random(K)
pi /= pi.sum()

T = len(obs_seq)
K = len(states)

def s(i):
return np.argwhere(obs == obs_seq[i]).flatten().item()

alpha = np.zeros((K, T))
beta = np.zeros((K, T))
alpha = np.zeros((T, K))
beta = np.zeros((T, K))
running = True

log = {
'tp': [], 'ep': [], 'pi': []
}

for _ in range(iters):
alpha[:, 0] = pi * ep[:, s(0)]
alpha[:, 0] /= alpha[:, 0].sum()
while running:
alpha[0] = pi * ep[:, s(0)]
alpha[0] /= alpha[0].sum()

for i in range(1, T):
alpha[:, i] = np.sum(alpha[:, i-1] * tp, axis=1) * ep[:, s(i)]
alpha[:, i] /= alpha[:, i].sum()
alpha[i] = np.sum(alpha[i-1] * tp, axis=1) * ep[:, s(i)]
alpha[i] /= alpha[i].sum()

beta[:, T-1] = end * ep[:, s(T-1)]
beta[:, T-1] /= beta[:, T-1].sum()
beta[T-1] = 1
beta[T-1] /= beta[T-1].sum()

for i in reversed(range(T-1)):
beta[:, i] = np.sum(beta[:, i+1] * tp *
ep[:, s(i+1)], axis=1) # i + 1
beta[:, i] /= beta[:, i].sum()
beta[i] = np.sum(
beta[i+1] * tp * ep[:, s(i+1)],
axis=1
) # i + 1
beta[i] /= beta[i].sum()

ksi = np.zeros((T, K, K))
gamma = np.zeros((T, K))

for i in range(T-1):
_t = alpha[:, i] * tp * beta[:, i+1] * ep[:, s(i+1)]
ksi[i] = _t / _t.sum()
ksi[i] = alpha[i] * tp * beta[i+1] * ep[:, s(i+1)]
ksi[i] /= ksi[i].sum()

pi = ksi[0].sum(axis=1)
tp = np.sum(ksi[:-1], axis=0) / ksi[:-1].sum(axis=2).sum(axis=0)
tp /= tp.sum(axis=1)[:, None]
gamma = ksi.sum(axis=2)
for n, ob in enumerate(obs):
ep[:, n] = gamma[np.argwhere(obs_seq == ob).ravel(), :].sum(
axis=0) / gamma.sum(axis=0)
gamma[i] = alpha[i] * beta[i]
gamma[i] /= gamma[i].sum()

y = np.argmax(gamma, axis=1)
x = states[y]
return x, y
_pi = gamma[1]
_tp = np.sum(ksi[:-1], axis=0) / gamma[:-1].sum(axis=0)
_tp /= _tp.sum(axis=1)[:, None]
_ep = np.zeros((K, N))

for n, ob in enumerate(obs):
_ep[:, n] = gamma[
np.argwhere(obs_seq == ob).ravel(), :
].sum(axis=0) / gamma.sum(axis=0)

tp_entropy = ss.entropy(tp.ravel(), _tp.ravel())
ep_entropy = ss.entropy(ep.ravel(), _ep.ravel())
pi_entropy = ss.entropy(pi, _pi)

log['tp'].append(tp_entropy)
log['ep'].append(ep_entropy)
log['pi'].append(pi_entropy)

if tp_entropy < thres and\
ep_entropy < thres and\
pi_entropy < thres:
running = False

ep = _ep.copy()
tp = _tp.copy()
pi = _pi.copy()

if not running:
break

model = self.__class__(
observations=obs, states=states, tp=tp, ep=ep, pi=pi
)
model.obs_seq = obs_seq
model.log = log
return model
3 changes: 1 addition & 2 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

setup(
name='mchmm',
version='0.3.3',
version='0.4.0',
description='Markov chains and Hidden Markov models',
long_description=open(join(dirname(__file__), 'DESCRIPTION.rst')).read(),
url='http://github.com/maximtrp/mchmm',
Expand All @@ -26,7 +26,6 @@
'License :: OSI Approved :: BSD License',

'Programming Language :: Python :: 3',
'Programming Language :: Python :: 3.5',
'Programming Language :: Python :: 3.6',
'Programming Language :: Python :: 3.7',
'Programming Language :: Python :: 3.8',
Expand Down
6 changes: 2 additions & 4 deletions tests/test_mchmm.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,11 +66,9 @@ def test_bw(self):
'''Checking Baum-Welch'''

obs_seq = 'AGACTGCATATATAAGGGGCAGGCTG'
sts_seq = '00000000111111100000000000'
a = hmm.HiddenMarkovModel().from_seq(obs_seq, sts_seq)
b = a.baum_welch(obs_seq, iters=3)
a = hmm.HiddenMarkovModel().from_baum_welch(obs_seq, states=['0', '1'])
self.assertTrue(
isinstance(b[0], np.ndarray) and isinstance(b[1], np.ndarray)
isinstance(a, hmm.HiddenMarkovModel)
)


Expand Down

0 comments on commit b947597

Please sign in to comment.