#!/usr/bin/env perl
# wren ng thornton, hw5, 600.465 Eisner
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

use LogProduct; # local -- slower, but prettier than doing the additions by hand
use warnings;
use strict;

# "New" counts: filled from the training file, and re-estimated from the raw data inside
# forwardbackward().  They are package variables rather than `my` lexicals because
# forwardbackward() `local`izes them, so its re-estimated values last only for that call
# and the training-only originals are restored afterwards.
our (%NCountT, %NCountW, %NCountTT, %NCountWT, $NCountAllWords);

# "Current" counts (what ptt()/pwt() read; refreshed from the N* variables by viterbi()),
# followed by the dictionaries, singleton counts and the loaded test/raw data.
my (%CountT, %CountW, %CountTT, %CountWT, $CountAllWords,
	%TagDict, %SeenDict, %SingTT, %SingWT, @AllTags, $CountVocab,
	@TestWords, @RealTestTags, @RawWords);

# Highest iteration number passed to forwardbackward(); numbering starts at 0
my $Iterations = 3;

# Require exactly three readable files.  Written as a block (not `print ... and exit`)
# so that the script always exits here, even if the print itself fails.
unless (3 == @ARGV and -r $ARGV[0] and -r $ARGV[1] and -r $ARGV[2]) {
	print STDERR "Usage: $0 trainingfile testfile rawfile\n\n";
	exit 1;
}

$| = 1; # Autoflush for debugging purposes

# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# Read in all the files
#
# From the three files named on the command line this block fills in:
#   training ($ARGV[0]): %NCountT, %NCountW, %NCountTT, %NCountWT, the singleton counts
#                        %SingTT and %SingWT, the tag dictionary %TagDict, and @AllTags
#   testing  ($ARGV[1]): @TestWords and @RealTestTags (parallel arrays)
#   raw      ($ARGV[2]): @RawWords and %SeenDict
# and finally the smoothing totals $NCountAllWords and $CountVocab.
{
	# Training
	open my $training, "<", $ARGV[0]
		or die "$0: Couldn't open training file: $!\n";

	# The first line must be the boundary marker.  It is consumed here and never
	# counted, so the token total excludes this one leading '###'.
	my $first_line = <$training>;
	die "$0: training file doesn't match specifications\n"
		unless defined $first_line and $first_line eq "###/###\n";

	my $prev_tag = '###';
	while (<$training>) {
		chomp;
		next unless $_;
		my ($word, $tag) = split '/';

		$NCountT{$tag}  += 1;
		$NCountW{$word} += 1;

		# Tag bigrams are keyed "tag/previous tag".  %SingTT{prev} tracks how many
		# distinct tags have followed prev exactly once: +1 on the first sighting,
		# -1 again when the same bigram is seen a second time.
		my $bigram = "$tag/$prev_tag";
		$NCountTT{$bigram} += 1;
		if ($NCountTT{$bigram} == 1) {
			$SingTT{$prev_tag} += 1;
		} elsif ($NCountTT{$bigram} == 2) {
			$SingTT{$prev_tag} -= 1;
		}

		# Same singleton bookkeeping for word/tag pairs; the first sighting of a
		# pair also records the tag as a possibility for that word.
		my $emission = "$word/$tag";
		$NCountWT{$emission} += 1;
		if ($NCountWT{$emission} == 1) {
			push @{$TagDict{$word}}, $tag;
			$SingWT{$tag} += 1;
		} elsif ($NCountWT{$emission} == 2) {
			$SingWT{$tag} -= 1;
		}

		$prev_tag = $tag;
	}
	close $training;


	# Needed to predict for novel words, only calculating so as to not include '###'
	foreach my $tag (keys %NCountT) {
		push @AllTags, $tag
			unless $tag eq '###';
	}


	# Read in the testing file
	open my $testing, "<", $ARGV[1]
		or die "$0: Couldn't open testing file: $!\n";
	while (<$testing>) {
		chomp;
		next unless $_;
		my ($word, $tag) = split '/';
		push @TestWords,    $word;
		push @RealTestTags, $tag;
	}
	close $testing;


	# Read in the raw file: one bare word per line
	open my $raw, "<", $ARGV[2]
		or die "$0: Couldn't open raw file: $!\n";
	while (<$raw>) {
		chomp;
		next unless length $_; # skip blank lines only; a truth test would also drop the word "0"
		push @RawWords, $_;
		$SeenDict{$_} = 1;
	}
	close $raw;


	# Needed for smoothing, needs to be down here so we can include raw vocab.
	# The vocabulary is the *union* of training and raw word types: a word that occurs
	# in both must contribute once, not once per source.
	my %vocab;
	@vocab{ keys %NCountW, keys %SeenDict } = ();
	$CountVocab = scalar(keys %vocab) + 1; # +1 for OOV

	# Total number of training tokens (raw-only words have no training count)
	$NCountAllWords = 0;
	$NCountAllWords += $_ foreach values %NCountW;
}


# ptt(previous tag, tag): smoothed transition probability p(tag | previous tag).
#
# N.B. the arguments come in the opposite order from the p(a|b) notation: the
# conditioning (previous) tag is first.  Callers conventionally write the call with a
# fat comma, ptt($prev => $tag), to make the direction visible.
#
# The estimate backs off to the unsmoothed unigram p(tag), weighted by the singleton
# count of the previous tag.  Dies if either tag has a zero (or missing) count.
sub ptt($$) {
	my ($ti_1, $ti) = @_;

	my $bigram_count = $CountTT{"$ti/$ti_1"} || 0;

	my $context_count = $CountT{$ti_1}
		or die "$0: count of tag $ti_1 is zero!\n";

	# A tiny positive weight avoids p==0 when both the bigram and singleton counts are 0
	my $weight = $SingTT{$ti_1} || 1e-100;

	my $tag_count = $CountT{$ti}
		or die "$0: count of tag $ti is zero!\n";
	my $unigram = $tag_count / $CountAllWords; # unsmoothed pt(ti)

	my $numerator   = $bigram_count + $weight * $unigram;
	my $denominator = $context_count + $weight;
	return $numerator / $denominator;
}


# pwt(tag, word): smoothed emission probability p(word | tag).
#
# N.B. as with ptt(), the conditioning value (the tag) is the FIRST argument even though
# the name reads "probability of word given tag".  Kept as its own sub so callers can
# evaluate it once per (word, tag) outside their inner loop.
#
# The estimate backs off to an add-one smoothed unigram p(word), weighted by the
# singleton count of the tag.  Dies if the tag has a zero (or missing) count.
sub pwt($$) {
	my ($ti, $wi) = @_;

	my $pair_count = $CountWT{"$wi/$ti"} || 0;

	my $tag_count = $CountT{$ti}
		or die "$0: count of tag $ti is zero!\n";

	# A tiny positive weight avoids p==0 when both the pair and singleton counts are 0
	my $weight = $SingWT{$ti} || 1e-100;

	# add-one pw(wi)
	my $word_count = $CountW{$wi} || 0;
	my $unigram    = ($word_count + 1) / ($CountAllWords + $CountVocab);

	my $numerator   = $pair_count + $weight * $unigram;
	my $denominator = $tag_count + $weight;
	return $numerator / $denominator;
}


# tag_dict(word): the candidate tags for a word.
# A word recorded in %TagDict gets exactly the tags stored for it there; any other word
# may take every tag in @AllTags, which is what makes tagging novel words possible.
# Uses `exists`, so looking a word up never adds it to %TagDict.
sub tag_dict($) {
	my ($word) = @_;
	return exists $TagDict{$word} ? @{ $TagDict{$word} } : @AllTags;
}


# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# Run Viterbi over testing data
#
# Takes no arguments.  Steps:
#   1. copies the "new" counts (%NCount*, $NCountAllWords) into the "current" ones
#      (%Count*, $CountAllWords) that ptt()/pwt() read;
#   2. finds the best tag sequence for @TestWords, starting from '###' at position 0 and
#      ending at '###' in the last position;
#   3. prints tagging accuracy (total / known / seen / novel words, skipping positions whose
#      true tag is '###') and the per-word perplexity of the best path.
# Dies if no path reaches the final '###' (the method call on the missing entry fails).
sub viterbi() {
	# Copy new (or original) into current
	%CountT  = %NCountT;
	%CountW  = %NCountW;
	%CountTT = %NCountTT;
	%CountWT = %NCountWT;
	$CountAllWords = $NCountAllWords;
	
	my (@tags, $perplexity); {
		my %backpointer; # "tag/i" => best previous tag for that state
		my %lmu;         # "tag/i" => LogProduct score of the best path into that state
		                 # (mu is the Viterbi approximation for alpha)
		$lmu{'###/0'} = LogProduct->new(1);
		
		foreach my $i (1..$#TestWords) {
			foreach my $ti (tag_dict $TestWords[$i]) {
				my $pwt_ti_wi = pwt($ti => $TestWords[$i]); # Pulling out for speed
				foreach my $ti_1 (tag_dict $TestWords[$i-1]) {
					
					if (my $candidate = $lmu{"$ti_1/".($i-1)}) {
						$candidate *= ptt($ti_1 => $ti) * $pwt_ti_wi;
						
						# Keep only the best-scoring way of arriving at ($ti, $i)
						if (not exists $lmu{"$ti/$i"} or $candidate > $lmu{"$ti/$i"}) {
							$lmu{"$ti/$i"}         = $candidate;
							$backpointer{"$ti/$i"} = $ti_1;
						}
					} else {
						warn "* zero probability for word $TestWords[$i] as tag $ti after tag $ti_1\n";
					}
				}
			}
		}
		
		$perplexity = exp (- $lmu{"###/$#TestWords"}->logvalue() / $#TestWords);
		warn "* Infinite perplexity due to zero probability\n"
			if $perplexity == LogProduct::Inf;
		
		# Generate the actual tag sequence by following back-pointers from the end.
		# Once the chain breaks (a state with no back-pointer, i.e. p==0), every earlier
		# tag stays undef instead of interpolating undef into the lookup key.
		$tags[$#TestWords] = '###';
		foreach my $i (reverse 1..$#TestWords) {
			$tags[$i-1] = defined $tags[$i] ? $backpointer{"$tags[$i]/$i"} : undef;
		}
	}
	
	
	# Report accuracy et al
	my (%correct, %count);
	foreach my $i (0..$#tags) {
		unless ($RealTestTags[$i] eq '###') {
			# A position with no predicted tag (broken back-pointer chain) is simply wrong
			my $match = defined($tags[$i]) && $tags[$i] eq $RealTestTags[$i];
			
			$correct{'total'} += 1 if $match;
			$count{'total'}   += 1;
			
			# known: in the training tag dictionary; seen: only in the raw data; novel: neither
			if (exists $TagDict{$TestWords[$i]}) { # `exists` is safe because tag_dict() doesn't autovivify
				$correct{'known'} += 1 if $match;
				$count{'known'}   += 1;
			} elsif (exists $SeenDict{$TestWords[$i]}) {
				$correct{'seen'}  += 1 if $match;
				$count{'seen'}    += 1;
			} else {
				$correct{'novel'} += 1 if $match;
				$count{'novel'}   += 1;
			}
		}
	}
	
	# A category with no words reports 0% rather than dividing by zero
	printf "Tagging accuracy: %.2f%% (known: %.2f%% seen: %.2f%% novel: %.2f%%)\n"
		. "Perplexity per tagged test word: %.3f\n",
		($count{'total'} ? 100 * ($correct{'total'} || 0) / $count{'total'} : 0),
		($count{'known'} ? 100 * ($correct{'known'} || 0) / $count{'known'} : 0),
		($count{'seen'}  ? 100 * ($correct{'seen'}  || 0) / $count{'seen'}  : 0),
		($count{'novel'} ? 100 * ($correct{'novel'} || 0) / $count{'novel'} : 0),
		$perplexity;
} # end viterbi()


# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# Do forward-backward over the raw data and then run viterbi on the new counts
#
# Argument: the iteration number, used only for the progress line that is printed.
#
# The "new" counts are `local`ized copies of the originals, so the fractional counts added
# here last only until this sub returns.  viterbi() is therefore called from inside: it
# copies the new counts into the "current" ones, which the next iteration's ptt()/pwt() read.
sub forwardbackward($) { my ($iteration) = @_;
	no strict 'vars'; # so we can use dynamic scoping
	# Copy new from original (dynamic scoping only saves us from using another set of names is all)
	# To ignore originals other than for the p()s the first time,
	# comment out the assigning portion leaving the local declaration (e.g. `local $foo;# = $foo;`)
	local %NCountT  = %NCountT;
	local %NCountW  = %NCountW;
	local %NCountTT = %NCountTT;
	local %NCountWT = %NCountWT;
	local $NCountAllWords = $NCountAllWords;
	
	# Do forward-backward using current to get counts into new
	# Forward pass: $lalpha{"tag/i"} accumulates over every previous tag
	my %lalpha; $lalpha{'###/0'} = LogProduct->new(1);
	foreach my $i (1..$#RawWords) {
		foreach my $ti (tag_dict $RawWords[$i]) {
			my $pwt_ti_wi = pwt($ti => $RawWords[$i]); # Pulling out for speed
			foreach my $ti_1 (tag_dict $RawWords[$i-1]) {
				$lalpha{"$ti/$i"} += $lalpha{"$ti_1/".($i-1)} * ( ptt($ti_1 => $ti) * $pwt_ti_wi );
			}
		}
	}
	my $ls = $lalpha{"###/$#RawWords"}; # total probability of all paths ###/0..###/n
	
	# Backward pass, collecting the expected counts on the way down
	my %lbeta; $lbeta{"###/$#RawWords"} = LogProduct->new(1);
	foreach my $i (reverse 1..$#RawWords) {
		$NCountAllWords         += 1;
		$NCountW{$RawWords[$i]} += 1;
		
		foreach my $ti (tag_dict $RawWords[$i]) {
			# Pulling counts out of log-space so ptt()/pwt() aren't slowed down dramatically
			# Hopefully counts should be existent enough to not underflow
			my $p_ti_i = $lalpha{"$ti/$i"} * $lbeta{"$ti/$i"} / $ls;
			$p_ti_i    = $p_ti_i->value();
			$NCountT{$ti}                  += $p_ti_i;
			$NCountWT{"$RawWords[$i]/$ti"} += $p_ti_i;
			
			my $pwt_ti_wi = pwt($ti => $RawWords[$i]); # Pulling out for speed
			foreach my $ti_1 (tag_dict $RawWords[$i-1]) {
				
				my $lpbeta = ( ptt($ti_1 => $ti) * $pwt_ti_wi ) * $lbeta{"$ti/$i"};
				
				$lbeta{"$ti_1/".($i-1)} += $lpbeta;
				
				# Pulling counts out of log-space, see note above
				my $count_tt = $lalpha{"$ti_1/".($i-1)} * $lpbeta / $ls;
				$NCountTT{"$ti/$ti_1"}  += $count_tt->value();
			}
		}
	}
	
	# $ls covers the raw words only, so normalise by the number of raw words after the
	# initial boundary -- the same convention viterbi() uses with $#TestWords -- and not by
	# $NCountAllWords, which also includes every training token.  The floor of 1 keeps a
	# raw file holding just the boundary line from dividing by zero.
	my $raw_word_count = $#RawWords > 0 ? $#RawWords : 1;
	printf "Iteration %s: Perplexity per untagged raw word: %.3f\n\n",
		$iteration,
		exp ( - $ls->logvalue() / $raw_word_count );
	
	viterbi(); # must be called in here in order to copy the new counts over to current
} # end forwardbackward()


# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# The rest of main() after we've read in the files

# Baseline: tag the test data with the counts from the training file alone
viterbi();

# Then re-estimate from the raw data.  Iterations are numbered from 0 up to and
# including $Iterations, and that number is what each call prints.
foreach my $round (0 .. $Iterations) {
	forwardbackward($round);
}
