{-# OPTIONS_GHC -Wall -fno-warn-missing-signatures -O2 #-}

----------------------------------------------------------------
--                                                  ~ 2009.03.25
-- |
-- Module      :  HMM
-- Copyright   :  Copyright (c) 2007--2009 wren ng thornton
-- License     :  BSD3
-- Maintainer  :  wren@community.haskell.org
-- Stability   :  example
-- Portability :  ?
--
-- This program implements a simple Hidden Markov Model tagger, mostly as an example of use and for benchmarking real-world use under different optimizations. We only use bigrams, so performance isn't going to be that great as a tagger. This is a translation from Perl and could stand to be cleaned up significantly.
----------------------------------------------------------------


module HMM where
import Prelude hiding (lookup, log, realToFrac, isInfinite, isNaN)

import Data.Number.LogFloat

import Data.Trie
import Data.Trie.Convenience (lookupWithDefault)

import qualified Data.ByteString.Char8 as S

import Control.Monad (liftM, when)
import Control.Arrow ((***))
import Data.List     (foldl', zip4)
import Data.Maybe    (fromJust)

import System.Environment (getArgs, getProgName)
import System.Exit        (exitFailure)
----------------------------------------------------------------

-- | All statistics collected while reading the training data. Every field
-- is a strict trie; composite keys are built with 'pair'. The tables are
-- maintained by 'updateCounts'.
data Counts = Counts
    { countsT     :: !(Trie Int)            -- ^ count per tag
    , countsW     :: !(Trie Int)            -- ^ count per word
    , countsTT    :: !(Trie Int)            -- ^ count per @tag `pair` prevTag@
    , countsWT    :: !(Trie Int)            -- ^ count per @word `pair` tag@
    , singletonTT :: !(Trie Int)            -- ^ per previous tag: tag bigrams currently seen exactly once
    , singletonWT :: !(Trie Int)            -- ^ per tag: word\/tag pairs currently seen exactly once
    , tagDict     :: !(Trie [S.ByteString]) -- ^ per word: the tags it has been observed with
    }
    deriving Show

-- | The starting point for training: a 'Counts' in which every table is empty.
emptyCounts :: Counts
emptyCounts = Counts
    { countsT     = empty
    , countsW     = empty
    , countsTT    = empty
    , countsWT    = empty
    , singletonTT = empty
    , singletonWT = empty
    , tagDict     = empty
    }

-- Field modifiers: each one runs a function over a single table of a
-- 'Counts' record and leaves the remaining tables as they were.
applyCountsT, applyCountsW, applyCountsTT, applyCountsWT,
    applySingletonTT, applySingletonWT
    :: (Trie Int -> Trie Int) -> Counts -> Counts
applyCountsT     f c@Counts{ countsT     = old } = c { countsT     = f old }
applyCountsW     f c@Counts{ countsW     = old } = c { countsW     = f old }
applyCountsTT    f c@Counts{ countsTT    = old } = c { countsTT    = f old }
applyCountsWT    f c@Counts{ countsWT    = old } = c { countsWT    = f old }
applySingletonTT f c@Counts{ singletonTT = old } = c { singletonTT = f old }
applySingletonWT f c@Counts{ singletonWT = old } = c { singletonWT = f old }

applyTagDict
    :: (Trie [S.ByteString] -> Trie [S.ByteString]) -> Counts -> Counts
applyTagDict     f c@Counts{ tagDict     = old } = c { tagDict     = f old }

-- The fixity is lower than that of the arithmetic operators, so an
-- expression such as @v + 1 :*: aw + c@ pairs up the two sums.
infix 1 :*:

-- | Strict pairs: both components are forced whenever the pair itself is
-- evaluated, which keeps fold accumulators from piling up thunks.
data (:*:) a b = !a :*: !b

-- Pairs are flattened into a single string so that they can serve as trie
-- keys (a habit inherited from the Perl original).

-- | Join two keys into one composite key with a slash in between.
pair :: S.ByteString -> S.ByteString -> S.ByteString
pair x y = S.concat [x, S.singleton '/', y]
-- | Split a @word/tag@ token at its first slash; the slash itself is
-- dropped. A token that contains no slash yields the whole input as the
-- first component and an empty second component ('S.drop' is total, whereas
-- 'S.tail' would throw on the empty remainder).
unpair :: S.ByteString -> (S.ByteString, S.ByteString)
unpair   = (id *** S.drop 1) . S.break ('/'==)
-- | The boundary marker. It serves both as a word and as that word's tag:
-- the training file has to open with the line @eof `pair` eof@.
eof :: S.ByteString
eof = S.pack "###"


-- Reinventing the state monad for fun and profit

-- | @runFor z xs body@ starts from state @z@ and, for each element of @xs@
-- from left to right, replaces the state by @body x state@. It is a strict
-- left fold with the arguments arranged so that the loop body comes last.
runFor :: s -> [a] -> (a -> s -> s) -> s
runFor z xs body = foldl' (\s x -> body x s) z xs
-- | @for xs body@ is the state transformer that applies @body x@ for every
-- element of @xs@ from left to right. Same loop as 'runFor', but with the
-- initial state as the final argument so that loops can be nested.
for :: [a] -> (a -> s -> s) -> s -> s
for xs body z = foldl' (\s x -> body x s) z xs
----------------------------------------------------------------

-- Train on the first file, decode the second one, then print the
-- perplexity followed by the list of (gold tag, predicted tag) pairs.
--
-- Both files are read as one slash-separated word and tag per line, with
-- blank lines ignored. The training file must open with the boundary line
-- built from 'eof'; anything else (including an empty file) is rejected.
main :: IO ()
main = do
    argv <- getArgs
    prog <- getProgName
    (trainFile, testFile) <- case argv of
        [train, test] -> return (train, test)
        _             -> do
            putStrLn ("Usage: "++prog++" trainingfile testfile")
            exitFailure

    trainLines <- S.lines `liftM` S.readFile trainFile
    rest <- case trainLines of
        (first:rest') | first == eof `pair` eof -> return rest'
        _ -> do
            putStrLn (prog++": training file doesn't match specifications")
            exitFailure

    -- Thread the previous tag alongside the counts; the leading boundary
    -- line supplies the first "previous tag".
    let viterbi = getViterbiProcess
                . (\(_ :*: c) -> c)
                . foldl' (\ (prevTag :*: c) (word,tag) ->
                            (tag :*: updateCounts word tag prevTag c))
                         (eof :*: emptyCounts)
                . map unpair
                . filter (not . S.null)
                $ rest

    (realWords,realTags) <- ( unzip
                            . map unpair
                            . filter (not . S.null)
                            . S.lines
                            ) `liftM` S.readFile testFile

    let (backpointers, perplexity) = viterbi realWords
    when (isInfinite perplexity)
        (putStrLn "* warning: infinite perplexity due to zero probability")

    -- Follow the backpointers from the last position down to the first.
    -- Each step conses the tag of the preceding position onto the front,
    -- so the finished list is already in sentence order and lines up with
    -- realTags without any reversal.
    let n = length realWords - 1
        previousTag i tag =
            lookupWithDefault Nothing
                (tag `pair` S.pack (show i)) backpointers
        tags = runFor [Just eof] [n,n-1..1] $ \i tags' ->
                   (head tags' >>= previousTag i) : tags'

    print perplexity
    print $ zip realTags tags -- TODO: report accuracy


-- | Record one observation of @word@ tagged @tag@ whose predecessor was
-- tagged @prevTag@.
--
-- The pipeline runs right to left: the tag and word counts are bumped, then
-- the tag-bigram count, then the word\/tag count. After each of the two
-- pair counts is bumped, the matching singleton table is adjusted: a pair
-- whose count just became 1 adds one singleton, a pair whose count just
-- became 2 removes one. A word\/tag pair seen for the first time also adds
-- the tag to the word's entry in the tag dictionary.
--
-- Counters and dictionary entries are created on first use with 'alterBy';
-- 'adjust' would silently leave the trie unchanged for a key that is not
-- present yet, so nothing would ever be counted.
updateCounts :: S.ByteString -> S.ByteString -> S.ByteString -> Counts -> Counts
updateCounts word tag prevTag
    = isSingleton (lookup w_t . countsWT)
          (applyTagDict addTag
              . applySingletonWT (increment tag))
          (applySingletonWT $ decrement tag)
    . applyCountsWT (increment w_t)

    . isSingleton (lookup t_pt . countsTT)
          (applySingletonTT $ increment prevTag)
          (applySingletonTT $ decrement prevTag)
    . applyCountsTT (increment t_pt)

    . applyCountsW (increment word)
    . applyCountsT (increment tag)
    where
    t_pt = tag `pair` prevTag
    w_t  = word `pair` tag

    -- Insert 1 for a new key, otherwise add 1 to the stored count.
    increment key = alterBy (\_ new old -> Just (maybe new (new +) old)) key 1

    -- Only ever applied to a key that 'increment' created earlier (a pair
    -- reaches count 2 only after it had count 1), so 'adjust' is enough.
    decrement = adjust (subtract 1)

    -- Start the word's tag list, or put the tag in front of the existing one.
    addTag = alterBy (\_ new old -> Just (maybe new (new ++) old)) word [tag]

    -- Choose an update according to the count found by @p@: @y@ when it is
    -- exactly 1, @n@ when it is exactly 2, no change otherwise.
    isSingleton p y n c' = case p c' of
                           Just 1 -> y c'
                           Just 2 -> n c'
                           _      -> c'


-- | Turn trained counts into a decoder: the result maps a list of words to
-- the backpointer table and perplexity computed by 'getViterbiProcess''.
getViterbiProcess :: Counts -> [S.ByteString] -> (Trie (Maybe S.ByteString), Double)
getViterbiProcess counts = getViterbiProcess' transition emission candidates
    where
    (transition, emission) = getProbabilityDistributions counts

    -- Every tag that has a count, apart from the boundary marker. Bound
    -- outside of @candidates@ so the list is shared between lookups.
    everyTag = [ t | t <- keys (countsT counts), eof /= t ]

    -- Tags to try for a word: its tag-dictionary entry when it has one,
    -- otherwise every known tag.
    candidates word = lookupWithDefault everyTag word (tagDict counts)


----------------------------------------------------------------
-- | Derive the two smoothed conditional distributions of the model from the
-- training counts.
--
-- The first component, @p_TT prevTag tag@, is the transition probability
-- and the second, @p_WT tag word@, is the emission probability. Both have
-- the shape @(count + lambda * backoff) / (contextCount + lambda)@, where
-- @lambda@ is the singleton count of the conditioning tag.
--
-- Looking up the count of a tag that was never seen in training calls
-- 'error'.
getProbabilityDistributions :: Counts -> (S.ByteString -> S.ByteString -> Double, S.ByteString -> S.ByteString -> Double)
getProbabilityDistributions counts = (p_TT, p_WT)
    where
    (countVocab :*: countAllWords) =
        -- Extra initial countVocab is for OOV
        -- (fromJust cannot fail: the keys come from the very trie that is
        -- being looked up.)
        runFor (1 :*: 0) (keys $ countsW counts) $ \ word ->
             \ (v :*: aw) -> v  + 1
                         :*: aw + fromJust (lookup word $ countsW counts)


    p_TT prevTag tag = (mycount + lambda * backoff) / (allcount + lambda)
        where
        mycount  = c_TT (pair tag prevTag)
        allcount = c_T  prevTag
        lambda   = s_TT prevTag
        -- Unsmoothed @p_T tag@
        backoff  = c_T tag / fromIntegral countAllWords

    c_TT key = fromIntegral . lookupWithDefault 0 key $ countsTT counts
    c_T  key = fromIntegral . lookupWithDefault
                   (error $ "zero count for tag: " ++ S.unpack key) key
                   $ countsT counts
    s_TT key = smoothingWeight (singletonTT counts) key


    p_WT tag word = (mycount + lambda * backoff) / (allcount + lambda)
        where
        mycount  = c_WT (pair word tag)
        allcount = c_T  tag
        lambda   = s_WT tag
        -- add-one smoothing on @p_W word@
        backoff  = (c_W word + 1) / fromIntegral(countAllWords + countVocab)

    c_WT key = fromIntegral . lookupWithDefault 0 key $ countsWT counts
    c_W  key = fromIntegral . lookupWithDefault 0 key $ countsW  counts
    s_WT key = smoothingWeight (singletonWT counts) key


    -- The singleton count for a key, or a tiny epsilon when there is none,
    -- so that p(...) stays above zero even when the pair count is 0. A
    -- stored count of 0 (a counter that went up and later came back down)
    -- is treated exactly like a missing entry; otherwise lambda would be 0
    -- and the epsilon would not do its job.
    smoothingWeight table key = case lookup key table of
        Just n | n > 0 -> fromIntegral n
        _              -> 1e-100


----------------------------------------------------------------
-- | Build the Viterbi decoder from a transition distribution
-- (@p_TT prevTag tag@), an emission distribution (@p_WT tag word@) and a
-- function giving the candidate tags of a word.
--
-- The decoder takes the words of the test data and returns
--
-- * a trie of backpointers: the key @tag `pair` i@ (with @i@ the decimal
--   position) maps to the best tag for position @i-1@, and
--
-- * the perplexity @exp (-lm / n)@, where @lm@ is the best log score
--   stored for 'eof' at the last position @n@. When no such entry exists
--   @lm@ is negative infinity and the perplexity is infinite.
--
-- A predecessor state that has no entry, or whose score is negative
-- infinity, contributes nothing and is skipped.
getViterbiProcess' p_TT p_WT tag_W = viterbi
    where
    -- This type signature is needed to avoid overlapping in the type instances for @perplexity@, iff getProbabilityDistributions is polymorphic.
    viterbi   :: [S.ByteString] -> (Trie (Maybe S.ByteString), Double)
    viterbi ws = (backpointers, perplexity)
        where
        backpointers = fmap (\(_ :*: bp) -> bp) logMuBP
        perplexity   = exp (negate lm / fromIntegral n)
            where
            n  = length ws - 1
            lm = case lookup (eof `pair` S.pack (show n)) logMuBP of
                 Just (lm' :*: _) -> lm'
                 Nothing          -> negativeInfinity -- no parse

        -- \mu is Viterbi approximation for \alpha
        --
        -- Starts with the boundary tag at position 0 (log score @log 1@, no
        -- backpointer) and then walks over adjacent word pairs, trying
        -- every candidate tag against every candidate previous tag.
        logMuBP =
            let ints = map (S.pack . show) [0 :: Int ..]
            in  runFor (singleton (eof `pair` S.pack "0") (log 1 :*: Nothing))
                (zip4 (tail ws) ws (tail ints) ints) $ \(word,prevWord,i,j) ->
                    for (tag_W word) $ \tag ->
                        for (tag_W prevWord) $ \prevTag ->
                            updatelogMuBP word prevWord i j tag prevTag


    -- Try to improve the entry for @tag@ at position @i@ by coming from
    -- @prevTag@ at position @j@.
    updatelogMuBP word _ i j tag prevTag lmbp =
        case lookup (prevTag `pair` j) lmbp of
            Just (lmPTJ :*: _)
                | lmPTJ /= negativeInfinity ->
                    alterBy keepBest
                            (tag `pair` i)
                            (lmPTJ + log (p_TT prevTag tag * p_WT tag word)
                                :*: Just prevTag)
                            lmbp

            -- The predecessor was never reached or has probability zero:
            -- there is no path through it, so the table stays as it is. If
            -- no path reaches the end at all, @viterbi@ reports an infinite
            -- perplexity instead of the whole run aborting here.
            _ -> lmbp
        where
        -- Keep whichever of the stored and the proposed entry scores
        -- higher; on a tie the stored one wins.
        keepBest _ new@(lmNew :*: _) old =
            case old of
                Nothing -> Just new
                Just (lmOld :*: _)
                    | lmOld < lmNew -> Just new
                    | otherwise     -> old

----------------------------------------------------------------
----------------------------------------------------------- fin.
