# Viterbi and Numerical Optimizations


I’m quite excited to write about my recent set of investigations on a seemingly simple mini-project. Recently I wrote a small piece of code for Machine Translation from English to Hindi using HMM Viterbi decoding. First I computed word alignments from a parallel corpus using the GIZA++ tool; after aligning the words in both languages, we get probability scores like the ones below. Please note that this is not really a post on Machine Translation (this is a very naive way to do MT), but a discussion of reusing HMM POS-Tagger code for a crude MT project and applying basic optimizations to make it fast.

| Src lang index | Dst lang index | Probability |
|----------------|----------------|-------------|
| 0              | 6              | 0.0105564   |
| 0              | 9              | 0.0888391   |
| 0              | 11             | 0.123789    |
| 0              | 12             | 8.75807e-06 |
| 0              | 14             | 0.000144811 |
| 0              | 17             | 3.12898e-05 |

Here the integers represent the index for each word in the dictionary created from the corpus.

The Machine Translation model is analogous to an HMM POS-Tagger: words in one language can be treated as hidden states transitioning from one to another (as defined by the language model), with each word (state) generating a word in the target language with the emission probability given by the alignments above.

The noisy channel model:

H* = argmax (over H) P(E|H) * P(H)


i.e., the best Hindi sentence H* for a given English sentence E is the one that maximizes the product of P(E|H), the probability of the English sentence given the Hindi one, and P(H), the probability of the Hindi sentence under the language model.
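In code, the decision rule amounts to something like the minimal sketch below, where `translation_prob` and `lm_prob` are hypothetical stand-ins for P(E|H) and P(H), not functions from my actual code:

```python
# A minimal sketch of the noisy-channel objective. translation_prob plays
# P(E|H) (the alignment/emission model) and lm_prob plays P(H) (the
# language model); both are hypothetical stand-ins for illustration.
def best_translation(english, candidate_hindi_sentences):
    return max(candidate_hindi_sentences,
               key=lambda h: translation_prob(english, h) * lm_prob(h))
```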

## Viterbi Decoding

Once this analogy between POS-Tagging and Machine Translation is set up, I simply ported the POS-Tagger code to Machine Translation, but there were surprises to be had. I ran the algorithm and it took ages to translate a single sentence. Let’s see what was wrong:

#### Initialization

SEQSCORE(1,1) = 1.0
BACKPTR(1,1) = 0
For i = 2 to N do
    SEQSCORE(i,1) = 0.0
[expressing the fact that the first state is S1]


#### Iteration

For t = 2 to T do
    For i = 1 to N do
        SEQSCORE(i,t) = Max over j=1..N of [ SEQSCORE(j, t-1) * P(Sj --ak--> Si) ]
        BACKPTR(i,t) = the index j that gives the Max above


Complexity: O(T x N x N)

POS-Tagger: hidden states = tags = 20 (say), words to decode T = 5000, complexity = T × N × N = 5000 × 20 × 20 = 2 × 10^6

Machine Translation: hidden states = Hindi words = 94125 ≈ 10^5, English words = 72167 ≈ 10^5, complexity ≈ 10^5 × 10^5 = 10^10 per input word

Clearly, because the number of hidden states in Machine Translation is the size of the vocabulary rather than a small tag set, the complexity is several orders of magnitude higher.
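To make the pseudocode above concrete, here is a minimal dense Viterbi sketch in Python; the names `words`, `transition`, and `emission` are illustrative NumPy arrays, not the actual variables from my code:

```python
import numpy as np

def viterbi(words, transition, emission):
    # words: observed word indices (length T)
    # transition: N x N array, transition[j, i] = P(Sj -> Si)
    # emission: V x N array, emission[w, i] = P(state i emits word w)
    N = transition.shape[0]
    T = len(words)
    seqscore = np.zeros((N, T))
    backptr = np.zeros((N, T), dtype=int)
    seqscore[0, 0] = 1.0                      # first state is S1

    for t in range(1, T):
        for i in range(N):                    # current state
            best_j, best = 0, 0.0
            for j in range(N):                # previous state
                score = seqscore[j, t - 1] * transition[j, i]
                if score > best:
                    best, best_j = score, j
            seqscore[i, t] = best * emission[words[t], i]
            backptr[i, t] = best_j

    # follow the back-pointers from the best final state
    path = [int(np.argmax(seqscore[:, T - 1]))]
    for t in range(T - 1, 0, -1):
        path.append(int(backptr[path[-1], t]))
    return path[::-1]
```

The T × N × N triple loop is exactly where the cost lives, and it is what the attempts below try to beat down.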

## Attempt one: Base approach

To start with, I had a transition matrix that contained transitions from every word to every other word, which I readily indexed as `transition[w_idx1][w_idx2]`. This was the most naive approach, and it did not even meet memory constraints: a dense 94125 × 94125 matrix of 8-byte floats needs roughly 70 GB, which made my laptop hang.

To solve this, I decided to use a SciPy sparse matrix, and after some research, given the way my transition matrix was structured, I went ahead with the CSR (compressed sparse row) format.
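For illustration, here is roughly how alignment triples like the ones in the table above can be packed into a CSR matrix; the variable names and the (English vocabulary × Hindi vocabulary) shape are my assumptions, not the original code:

```python
import numpy as np
from scipy.sparse import csr_matrix

# Pack (src index, dst index, probability) triples -- e.g. the GIZA++
# alignment scores shown earlier -- into a CSR matrix. The shape
# (English vocabulary x Hindi vocabulary) is assumed for illustration.
rows = np.array([0, 0, 0, 0, 0, 0])
cols = np.array([6, 9, 11, 12, 14, 17])
vals = np.array([0.0105564, 0.0888391, 0.123789,
                 8.75807e-06, 0.000144811, 3.12898e-05])

emission = csr_matrix((vals, (rows, cols)), shape=(72167, 94125))
print(emission[0, 11])   # -> 0.123789
```

CSR stores each row’s nonzero entries contiguously, which suits a decoder that repeatedly asks “which states does this state lead to?”.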

## Attempt two

After a discussion with one of my friends, it occurred to me that the above algorithm naively checks transitions from every state to every other state, whereas in a real language one word leads to only a handful of other words. I was unnecessarily iterating over null transitions!

Let’s take a look at the code:

# 't' is the position of the current word in the sentence
for idx1 in range(N):          # candidate current state
    maxScore = 0.0
    for idx2 in range(N):      # candidate previous state, null transitions included!
        score = seqscore[idx2][t-1] * transition[idx2][idx1] * emission[wordindex][idx1]
        if score > maxScore:
            maxScore = score
            backptr[idx1][t] = idx2
    seqscore[idx1][t] = maxScore


Clearly the transition evaluation needs some love. I googled for ways to get only the nonzero entries of a sparse matrix row and found `scipy.sparse.find`, so let’s see:

from scipy.sparse import find

tr_row = transition.getrow(idx1)   # 1 x N sparse row for the current state
.
.
.
# find() returns the (row indices, column indices, values) of the nonzero entries
rs, indices, vals = find(tr_row)
for idx2, v2 in zip(indices, vals):
    # only nonzero transitions are visited now


I did end up with some nasty code in the inner loop, but that’s how it was; I had to trade off an easy-to-understand implementation for speed.

As a result, I got an instant improvement in speed! The search method, which earlier took ages for one sentence, now took 5 minutes. However, I was still not satisfied, so I decided to dig further, this time with proper tooling ;)

## Attempt three

I remembered that Python has a profiling library, so I decided to do some actual profiling.

Using

python -m cProfile [-o output_file] [-s sort_order] myscript.py


I could save my stats to a file, which I later retrieved as:

import pstats
p = pstats.Stats('stats')                     # 'stats' is the file saved by cProfile
p.sort_stats('cumulative').print_stats(50)    # top 50 entries by cumulative time


The result, after ordering by cumulative time:

   Ordered by: cumulative time
List reduced from 2002 to 50 due to restriction <50>

ncalls  tottime  percall  cumtime  percall filename:lineno(function)
557/1    0.040    0.000  197.171  197.171 {built-in method builtins.exec}
1    0.005    0.005  197.171  197.171 mt.py:1(<module>)
1    0.000    0.000  196.803  196.803 mt.py:122(main)
1    0.031    0.031  196.803  196.803 mt.py:100(train)
1   17.758   17.758  182.292  182.292 mt.py:51(search)
564750    4.789    0.000  112.658    0.000 /usr/lib/python3.6/site-packages/scipy/sparse/extract.py:14(find)
1129511/564761    7.526    0.000   86.061    0.000 /usr/lib/python3.6/site-packages/scipy/sparse/coo.py:118(__init__)
564757    5.523    0.000   53.782    0.000 /usr/lib/python3.6/site-packages/scipy/sparse/compressed.py:905(tocoo)
1129511   20.949    0.000   52.660    0.000 /usr/lib/python3.6/site-packages/scipy/sparse/coo.py:212(_check)
564757    0.943    0.000   51.593    0.000 /usr/lib/python3.6/site-packages/scipy/sparse/csr.py:356(getrow)
564757    3.779    0.000   50.650    0.000 /usr/lib/python3.6/site-packages/scipy/sparse/csr.py:411(_get_submatrix)
564770/564764    4.156    0.000   40.853    0.000 /usr/lib/python3.6/site-packages/scipy/sparse/compressed.py:24(__init__)



From the above data, it is clear that ‘scipy/sparse/coo.py’ accounted for the largest share of the calls. After some digging, I found that although I was cutting down the number of transitions by evaluating only the nonzero ones, the `getrow` method that SciPy provides for fetching an individual row of a sparse matrix is itself quite expensive: as the csr/coo `__init__` entries in the profile show, every call builds a fresh sparse matrix object for that row.
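For completeness, a common way around this (noted here only as a sketch; it is not what my code did) is to bypass `getrow` and slice the raw CSR arrays directly, since row i’s nonzero entries sit at `indptr[i]:indptr[i+1]` in the `indices` and `data` arrays:

```python
# Sketch: read one row of a CSR matrix without getrow(), so no new
# sparse matrix object is constructed per row. For a csr_matrix, row i's
# nonzero column indices and values live at indptr[i]:indptr[i+1].
start, end = transition.indptr[idx1], transition.indptr[idx1 + 1]
for idx2, v2 in zip(transition.indices[start:end],
                    transition.data[start:end]):
    ...  # same inner-loop work as before, minus the per-row overhead
```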

## Wrap-up

Whoa, so much for a simple Machine Translation task! But it was a good exercise that made me familiar with sparse matrices, real memory and time optimizations at the data-structure level, and profiling!
