Viterbi and Numerical Optimizations

I’m quite excited to write about my recent set of investigations on a seemingly simple mini-project. Recently I wrote a small program for Machine Translation from English to Hindi using HMM Viterbi decoding. Please note that this is not really a post on Machine Translation, as this is a very naive way to do it; it is rather a discussion of reusing HMM POS-Tagger code for a crude MT project and applying basic optimizations to make it fast.

First, I computed word alignments from a parallel corpus using the GIZA++ tool. After aligning the words in both languages, we get probability scores like these:

Src lang index    Dst lang index    Probability
0                 6                 0.0105564
0                 9                 0.0888391
0                 11                0.123789
0                 12                8.75807e-06
0                 14                0.000144811
0                 17                3.12898e-05

Here the integers represent the index for each word in the dictionary created from the corpus.
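
As a rough sketch, assuming the alignment scores come as whitespace-separated "src_index dst_index probability" lines like the table above, they can be loaded into a nested dictionary that stores only the pairs that actually aligned. The helper name load_alignment_table is hypothetical; it is not part of GIZA++.

```python
from collections import defaultdict


def load_alignment_table(lines):
    """Parse 'src dst prob' lines into {src_index: {dst_index: prob}}.

    Hypothetical helper: assumes one whitespace-separated triple per line.
    """
    table = defaultdict(dict)
    for line in lines:
        parts = line.split()
        if len(parts) != 3:
            continue  # skip blank or malformed lines
        src, dst, prob = int(parts[0]), int(parts[1]), float(parts[2])
        table[src][dst] = prob
    return dict(table)


sample = """0 6 0.0105564
0 9 0.0888391
0 11 0.123789
0 12 8.75807e-06
0 14 0.000144811
0 17 3.12898e-05"""

table = load_alignment_table(sample.splitlines())
print(table[0][11])  # 0.123789
```

Word pairs that never aligned simply have no entry, which is the same sparsity that the later optimizations exploit.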

The Machine Translation model is analogous to an HMM POS-Tagger: the words of one language (Hindi) are treated as hidden states that transition from one to another according to a language model, and each word (state) emits a word of the other language (English) with the emission probability given by the alignments above.

The noisy channel model:

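$$\hat{H} = \arg\max_{H} P(H \mid E) = \arg\max_{H} P(E \mid H)\,P(H)$$

i.e., the best Hindi translation of an English sentence E is the Hindi sentence H that maximizes the product of P(E|H), the probability of the English sentence given the Hindi one (from the alignments), and P(H), the probability of the Hindi sentence (from the language model). By Bayes' rule P(H|E) = P(E|H) P(H) / P(E), and P(E) can be dropped because it is the same for every candidate H.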
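
A minimal sketch of that decision rule, assuming the channel model and the language model are available as log-probability functions (the names below are hypothetical). Working in log space turns the product into a sum and avoids floating-point underflow on long sentences.

```python
import math


def noisy_channel_best(english, candidates, channel_logprob, lm_logprob):
    """Return the candidate H maximizing log P(E|H) + log P(H)."""
    return max(candidates,
               key=lambda hindi: channel_logprob(english, hindi) + lm_logprob(hindi))


# Toy, made-up probabilities purely for illustration.
channel = {("house", "ghar"): 0.6, ("house", "makaan"): 0.4}
lm = {"ghar": 0.01, "makaan": 0.002}

best = noisy_channel_best(
    "house",
    ["ghar", "makaan"],
    lambda e, h: math.log(channel[(e, h)]),
    lambda h: math.log(lm[h]),
)
print(best)  # ghar
```

In the real system the candidate Hindi sentences are never listed explicitly like this; Viterbi decoding searches over them implicitly.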

Viterbi Decoding

Once this analogy between POS-Tagging and Machine Translation was set up, I simply ported the POS-Tagger code to Machine Translation, but there were surprises to be had. I ran the algorithm and it took ages to translate a single sentence. Let's see what was wrong:

Initialization

SEQSCORE(1,1) = 1.0
BACKPTR(1,1) = 0
For (i = 2 to N) do
    SEQSCORE(i,1) = 0.0
[expressing the fact that the first state is S1]

Iteration

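For (t = 2 to T) do
    For (i = 1 to N) do
        SEQSCORE(i,t) = Max (j = 1 to N) [ SEQSCORE(j, t−1) * P(Sj --ak--> Si) ]
        BACKPTR(i,t) = index j that gives the MAX above

Here P(Sj --ak--> Si) is the probability of moving from state Sj to state Si while emitting the observed symbol ak (the input word consumed on that step); with separate tables, it is the transition probability times the emission probability.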
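
A runnable sketch of the same recurrence in plain Python. Unlike the pseudocode, it assumes separate transition and emission tables and a start distribution instead of a fixed start state S1, and it adds the back-pointer walk at the end; the function name and the toy model are illustrative assumptions. The three nested loops (positions, current states, previous states) are exactly where the O(T x N x N) cost comes from.

```python
def viterbi(observations, start, transition, emission):
    """Dense Viterbi decoding in O(T x N x N) time.

    start[i]         : P(first state is i)
    transition[j][i] : P(next state is i | current state is j)
    emission[i][o]   : P(observed symbol o | state i)
    observations     : non-empty list of observed symbol indices
    Returns the most probable state sequence as a list of state indices.
    """
    n, T = len(start), len(observations)
    seqscore = [[0.0] * T for _ in range(n)]   # seqscore[i][t], as in SEQSCORE(i,t)
    backptr = [[0] * T for _ in range(n)]      # backptr[i][t], as in BACKPTR(i,t)

    # Initialization: start distribution times emission of the first symbol.
    for i in range(n):
        seqscore[i][0] = start[i] * emission[i][observations[0]]

    # Iteration: T positions x N current states x N previous states.
    for t in range(1, T):
        o = observations[t]
        for i in range(n):
            best_j, best = 0, 0.0
            for j in range(n):
                score = seqscore[j][t - 1] * transition[j][i]
                if score > best:
                    best_j, best = j, score
            # Emission does not depend on j, so it can be applied after the max.
            seqscore[i][t] = best * emission[i][o]
            backptr[i][t] = best_j

    # Termination: best final state, then follow the back-pointers.
    last = max(range(n), key=lambda i: seqscore[i][T - 1])
    path = [last]
    for t in range(T - 1, 0, -1):
        path.append(backptr[path[-1]][t])
    return path[::-1]


# Toy 2-state model with made-up numbers.
start = [0.6, 0.4]
transition = [[0.7, 0.3], [0.4, 0.6]]
emission = [[0.9, 0.1], [0.2, 0.8]]
path = viterbi([0, 1, 1], start, transition, emission)
print(path)  # [0, 1, 1]
```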

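Complexity: O(T x N x N), where T is the number of input words and N the number of hidden states

POS-Tagger: hidden states = tags = 20 (say), so each input word costs N x N = 400 steps; even a 5000-word text needs only T x N x N = 5000 x 400 = 2 x 10^6 steps.

Machine Translation: hidden states = Hindi words = 94125 ~ 10^5 (the English vocabulary, 72167 ~ 10^5, is the set of observed symbols), so every single input word costs N x N ~ 10^10 steps.

Clearly, because the number of hidden states in Machine Translation equals the size of the Hindi vocabulary, and the cost grows with the square of that number, each input word costs roughly seven orders of magnitude more than in POS tagging (10^10 vs 400).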
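
To make the T x N x N cost concrete, the snippet below counts the scoring steps dense Viterbi performs for one sentence in both settings; the 20-word sentence length is an assumption for illustration.

```python
def viterbi_inner_steps(sentence_length, num_states):
    """Number of (previous state, current state) pairs scored by dense Viterbi."""
    return sentence_length * num_states * num_states


pos = viterbi_inner_steps(20, 20)        # 20 POS tags as states
mt = viterbi_inner_steps(20, 94125)      # 94125 Hindi words as states
print(pos, mt)
```

For the MT setting this comes to about 1.8 x 10^11 steps for a single 20-word sentence, against 8,000 for the tagger.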

Can we do better?

Attempt one: Base approach

To start with, I had a dense transition matrix holding the transition from every word to every other word (roughly 94125 x 94125 entries), which I indexed directly as transition[w_idx1][w_idx2]. This was the most naive approach, and it did not even fit in memory: it simply made my laptop hang.

To solve this, I decided to use a SciPy sparse matrix. After some research, and given the way my transition matrix was structured, I went with the CSR (compressed sparse row) format.
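
A small sketch of building such a matrix with scipy.sparse.csr_matrix from (row, column, probability) triplets; the toy bigram probabilities are made up for illustration. CSR keeps only three arrays (data, indices and indptr), so memory grows with the number of non-zero transitions instead of N x N. For comparison, a dense matrix of 8-byte floats over the 94125 Hindi words would need about 94125 x 94125 x 8 bytes, roughly 71 GB.

```python
import numpy as np
from scipy.sparse import csr_matrix

# Hypothetical toy bigram transitions: (previous word, next word, probability).
rows = np.array([0, 0, 1, 2, 2, 3])
cols = np.array([1, 2, 3, 0, 3, 1])
probs = np.array([0.5, 0.5, 1.0, 0.3, 0.7, 1.0])

num_words = 4
transition = csr_matrix((probs, (rows, cols)), shape=(num_words, num_words))

print(transition.nnz)        # 6 stored entries instead of 16
print(transition[2, 3])      # 0.7
print(transition[1, 0])      # 0.0 (absent pairs take no memory)
```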

Attempt two

After a discussion with one of my friends, it occurred to me that the algorithm above naively checks the transition from every state to every other state, whereas in a real language a word is followed by only a handful of other words. I was needlessly iterating over zero transitions!

Let's take a look at the code:
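
A sketch of that idea using plain dictionaries (a data layout assumed here for illustration, not the CSR matrix used in the project): instead of asking, for every pair of states, whether a transition exists, each Viterbi step pushes scores forward only along transitions that are actually present. The work per step is then the number of non-zero transitions leaving the currently active states, rather than N x N. The function and argument names are hypothetical.

```python
def sparse_viterbi_step(prev_scores, successors, emission_col):
    """One Viterbi step that only touches non-zero transitions.

    prev_scores : {state: score} for states with non-zero score at t-1
    successors  : {state j: {state i: P(i | j)}}, non-zero entries only
    emission_col: {state i: P(current word | i)}, non-zero entries only
    Returns ({state: score}, {state: best previous state}).
    """
    scores, backptr = {}, {}
    for j, prev in prev_scores.items():
        for i, p_trans in successors.get(j, {}).items():
            p_emit = emission_col.get(i, 0.0)
            if p_emit == 0.0:
                continue  # state i cannot produce the current word
            score = prev * p_trans * p_emit
            if score > scores.get(i, 0.0):
                scores[i] = score
                backptr[i] = j
    return scores, backptr


# Toy example with made-up numbers.
prev = {0: 0.5, 1: 0.5}
succ = {0: {1: 0.9, 2: 0.1}, 1: {2: 1.0}}
emit = {1: 0.2, 2: 0.5}
scores, backptr = sparse_viterbi_step(prev, succ, emit)
print(scores, backptr)
```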

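# 't' is the position of the current word, 'wordindex' its dictionary index;
# num_states (number of hidden states) and backptr are illustrative names
for idx1 in range(num_states):            # current state
    maxScore, best = 0.0, 0
    for idx2 in range(num_states):        # previous state
        score = seqscore[idx2][t-1] * transition[idx2][idx1] * emission[wordindex][idx1]
        if score > maxScore:
            maxScore, best = score, idx2
    seqscore[idx1][t] = maxScore
    backptr[idx1][t] = best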

Clearly the transition evaluation needs some love. I searched for a way to get only the non-zero entries of a row of the transition matrix and found scipy.sparse.find, so let's see:

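from scipy.sparse import find

tr_row = transition.getrow(idx1)
# ...
rs, indices, vals = find(tr_row)   # row indices, column indices, non-zero values
for idx2, v2 in zip(indices, vals):
    ...  # score only the non-zero transitions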

The inner loop did get a bit nasty, but that's how it was: I had to trade off an easy-to-understand implementation for speed.

As a result, I got an instant improvement in speed! The search, which used to take ages for a single sentence, now took 5 minutes. However, I was still not satisfied, so I decided to dig further, this time using proper tools ;)

Attempt three

I remembered that Python ships with a profiler in its standard library, so I decided to do some actual profiling.

Using

python -m cProfile [-o output_file] [-s sort_order] myscript.py

I could save my stats to a file, which I later retrieved as:

import pstats
p = pstats.Stats('stats')
p.sort_stats('cumulative').print_stats(50)
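
The same kind of report can also be produced from inside a script with the standard-library cProfile.Profile class, which is handy for profiling just the search call rather than the whole program. A minimal sketch; slow_sum is a placeholder workload, not code from the project.

```python
import cProfile
import io
import pstats


def slow_sum(n):
    # Placeholder workload; replace with the function you want to profile.
    return sum(i * i for i in range(n))


profiler = cProfile.Profile()
profiler.enable()
result = slow_sum(200_000)
profiler.disable()

stream = io.StringIO()
stats = pstats.Stats(profiler, stream=stream)
stats.sort_stats("cumulative").print_stats(10)  # top 10 entries by cumulative time
report = stream.getvalue()
print(report)
```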

The result, after ordering by cumulative time:

   Ordered by: cumulative time
   List reduced from 2002 to 50 due to restriction <50>

   ncalls  tottime  percall  cumtime  percall filename:lineno(function)
    557/1    0.040    0.000  197.171  197.171 {built-in method builtins.exec}
        1    0.005    0.005  197.171  197.171 mt.py:1(<module>)
        1    0.000    0.000  196.803  196.803 mt.py:122(main)
        1    0.031    0.031  196.803  196.803 mt.py:100(train)
        1   17.758   17.758  182.292  182.292 mt.py:51(search)
   564750    4.789    0.000  112.658    0.000 /usr/lib/python3.6/site-packages/scipy/sparse/extract.py:14(find)
1129511/564761    7.526    0.000   86.061    0.000 /usr/lib/python3.6/site-packages/scipy/sparse/coo.py:118(__init__)
   564757    5.523    0.000   53.782    0.000 /usr/lib/python3.6/site-packages/scipy/sparse/compressed.py:905(tocoo)
  1129511   20.949    0.000   52.660    0.000 /usr/lib/python3.6/site-packages/scipy/sparse/coo.py:212(_check)
   564757    0.943    0.000   51.593    0.000 /usr/lib/python3.6/site-packages/scipy/sparse/csr.py:356(getrow)
   564757    3.779    0.000   50.650    0.000 /usr/lib/python3.6/site-packages/scipy/sparse/csr.py:411(_get_submatrix)
564770/564764    4.156    0.000   40.853    0.000 /usr/lib/python3.6/site-packages/scipy/sparse/compressed.py:24(__init__)

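From the data above, search (mt.py:51) accounts for 182 of the 197 seconds, and most of that is spent inside SciPy rather than in my own code. find alone takes about 113 seconds cumulatively, because every call converts the row to COO format (tocoo -> coo.py __init__ -> _check), and those COO functions are the most frequently called ones in this listing (1,129,511 calls each). getrow adds another 52 seconds or so by building a new CSR submatrix for every single row. So although I reduced the number of transitions by evaluating only the non-zero ones, extracting individual rows of a sparse matrix with getrow and find is expensive in itself, and with about 565,000 calls that overhead dominates.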
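
One common way to avoid that per-row overhead (a technique swapped in here, not what the profile above measured) is to read a row's non-zeros straight from the CSR arrays: the entries of row r live in indices and data between indptr[r] and indptr[r + 1]. A minimal sketch, assuming the matrix is in canonical form (call transition.sum_duplicates() once after building it); row_nonzeros is a hypothetical helper name.

```python
import numpy as np
from scipy.sparse import csr_matrix


def row_nonzeros(matrix, row):
    """Return (column indices, values) of one CSR row without building a new matrix.

    Slices the CSR arrays directly: the entries of `row` sit in
    matrix.indices / matrix.data between matrix.indptr[row] and matrix.indptr[row + 1].
    """
    start, end = matrix.indptr[row], matrix.indptr[row + 1]
    return matrix.indices[start:end], matrix.data[start:end]


transition = csr_matrix(np.array([[0.0, 0.5, 0.5],
                                  [0.0, 0.0, 1.0],
                                  [0.3, 0.7, 0.0]]))
cols, vals = row_nonzeros(transition, 2)
print(cols, vals)  # [0 1] [0.3 0.7]
```

Because this only slices existing arrays, it should sidestep both the submatrix construction in getrow and the COO conversion in find. If the loop needs the predecessors of state i (column i, as in transition[idx2][idx1]), build transition.T.tocsr() once up front and slice its rows instead.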

Wrap-up

Whoa, so much for a simple Machine Translation task! But it was a good exercise that made me familiar with sparse matrices, real optimizations of memory and time at the data-structure level, and profiling!