diff options
author | Avneesh Saluja <asaluja@gmail.com> | 2013-03-28 18:28:16 -0700 |
---|---|---|
committer | Avneesh Saluja <asaluja@gmail.com> | 2013-03-28 18:28:16 -0700 |
commit | 5b8253e0e1f1393a509fb9975ba8c1347af758ed (patch) | |
tree | 1790470b1d07a0b4973ebce19192e896566ea60b /training/crf | |
parent | 2389a5a8a43dda87c355579838559515b0428421 (diff) | |
parent | b203f8c5dc8cff1b9c9c2073832b248fcad0765a (diff) |
fixed conflicts
Diffstat (limited to 'training/crf')
-rw-r--r-- | training/crf/Makefile.am | 31 | ||||
-rw-r--r-- | training/crf/baum_welch_example/README.md | 32 | ||||
-rw-r--r-- | training/crf/baum_welch_example/cdec.ini | 5 | ||||
-rwxr-xr-x | training/crf/baum_welch_example/random_init.pl | 9 | ||||
-rw-r--r-- | training/crf/baum_welch_example/tagset.txt | 1 | ||||
-rw-r--r-- | training/crf/baum_welch_example/train.txt | 2000 | ||||
-rw-r--r-- | training/crf/cllh_observer.cc | 52 | ||||
-rw-r--r-- | training/crf/cllh_observer.h | 26 | ||||
-rw-r--r-- | training/crf/mpi_batch_optimize.cc | 372 | ||||
-rw-r--r-- | training/crf/mpi_baum_welch.cc | 316 | ||||
-rw-r--r-- | training/crf/mpi_compute_cllh.cc | 134 | ||||
-rw-r--r-- | training/crf/mpi_extract_features.cc | 151 | ||||
-rw-r--r-- | training/crf/mpi_extract_reachable.cc | 163 | ||||
-rw-r--r-- | training/crf/mpi_flex_optimize.cc | 386 | ||||
-rw-r--r-- | training/crf/mpi_online_optimize.cc | 384 |
15 files changed, 4062 insertions, 0 deletions
diff --git a/training/crf/Makefile.am b/training/crf/Makefile.am new file mode 100644 index 00000000..4a8c30fd --- /dev/null +++ b/training/crf/Makefile.am @@ -0,0 +1,31 @@ +bin_PROGRAMS = \ + mpi_batch_optimize \ + mpi_compute_cllh \ + mpi_extract_features \ + mpi_extract_reachable \ + mpi_flex_optimize \ + mpi_online_optimize \ + mpi_baum_welch + +mpi_baum_welch_SOURCES = mpi_baum_welch.cc +mpi_baum_welch_LDADD = ../../decoder/libcdec.a ../../klm/search/libksearch.a ../../mteval/libmteval.a ../../utils/libutils.a ../../klm/lm/libklm.a ../../klm/util/libklm_util.a ../../klm/util/double-conversion/libklm_util_double.a -lz + +mpi_online_optimize_SOURCES = mpi_online_optimize.cc +mpi_online_optimize_LDADD = ../../training/utils/libtraining_utils.a ../../decoder/libcdec.a ../../klm/search/libksearch.a ../../mteval/libmteval.a ../../utils/libutils.a ../../klm/lm/libklm.a ../../klm/util/libklm_util.a ../../klm/util/double-conversion/libklm_util_double.a -lz + +mpi_flex_optimize_SOURCES = mpi_flex_optimize.cc +mpi_flex_optimize_LDADD = ../../training/utils/libtraining_utils.a ../../decoder/libcdec.a ../../klm/search/libksearch.a ../../mteval/libmteval.a ../../utils/libutils.a ../../klm/lm/libklm.a ../../klm/util/libklm_util.a ../../klm/util/double-conversion/libklm_util_double.a -lz + +mpi_extract_reachable_SOURCES = mpi_extract_reachable.cc +mpi_extract_reachable_LDADD = ../../decoder/libcdec.a ../../klm/search/libksearch.a ../../mteval/libmteval.a ../../utils/libutils.a ../../klm/lm/libklm.a ../../klm/util/libklm_util.a ../../klm/util/double-conversion/libklm_util_double.a -lz + +mpi_extract_features_SOURCES = mpi_extract_features.cc +mpi_extract_features_LDADD = ../../decoder/libcdec.a ../../klm/search/libksearch.a ../../mteval/libmteval.a ../../utils/libutils.a ../../klm/lm/libklm.a ../../klm/util/libklm_util.a ../../klm/util/double-conversion/libklm_util_double.a -lz + +mpi_batch_optimize_SOURCES = mpi_batch_optimize.cc cllh_observer.cc cllh_observer.h +mpi_batch_optimize_LDADD = ../../training/utils/libtraining_utils.a ../../decoder/libcdec.a ../../klm/search/libksearch.a ../../mteval/libmteval.a ../../utils/libutils.a ../../klm/lm/libklm.a ../../klm/util/libklm_util.a ../../klm/util/double-conversion/libklm_util_double.a -lz + +mpi_compute_cllh_SOURCES = mpi_compute_cllh.cc cllh_observer.cc cllh_observer.h +mpi_compute_cllh_LDADD = ../../decoder/libcdec.a ../../klm/search/libksearch.a ../../mteval/libmteval.a ../../utils/libutils.a ../../klm/lm/libklm.a ../../klm/util/libklm_util.a ../../klm/util/double-conversion/libklm_util_double.a -lz + +AM_CPPFLAGS = -DBOOST_TEST_DYN_LINK -W -Wall -Wno-sign-compare -I$(top_srcdir)/training -I$(top_srcdir)/training/utils -I$(top_srcdir)/utils -I$(top_srcdir)/decoder -I$(top_srcdir)/mteval diff --git a/training/crf/baum_welch_example/README.md b/training/crf/baum_welch_example/README.md new file mode 100644 index 00000000..97525da5 --- /dev/null +++ b/training/crf/baum_welch_example/README.md @@ -0,0 +1,32 @@ +Here's how to do Baum-Welch training with `cdec`. + +## Set the tags you want. + +First, set the number of tags you want in tagset.txt (these +can be any symbols, listed one after another, separated +by whitespace), e.g.: + + C1 C2 C3 C4 + +## Extract the parameter feature names + + ../mpi_extract_features -c cdec.ini -t train.txt + +If you have compiled with MPI, you can use `mpirun`: + + mpirun -np 8 ../mpi_extract_features -c cdec.ini -t train.txt + +## Randomly initialize the weights file + + sort -u features.* | ./random_init.pl > weights.init + +## Run training + + ../mpi_baum_welch -c cdec.ini -t train.txt -w weights.init -n 50 + +Again, if you have compiled with MPI, you can use `mpirun`: + + mpirun -np 8 ../mpi_baum_welch -c cdec.ini -t train.txt -w weights.init -n 50 + +The `-n` flag indicates how many iterations to run for. + diff --git a/training/crf/baum_welch_example/cdec.ini b/training/crf/baum_welch_example/cdec.ini new file mode 100644 index 00000000..61203da7 --- /dev/null +++ b/training/crf/baum_welch_example/cdec.ini @@ -0,0 +1,5 @@ +feature_function=Tagger_BigramIndicator +feature_function=LexicalPairIndicator +formalism=tagger +tagger_tagset=tagset.txt +intersection_strategy=full diff --git a/training/crf/baum_welch_example/random_init.pl b/training/crf/baum_welch_example/random_init.pl new file mode 100755 index 00000000..98467ed1 --- /dev/null +++ b/training/crf/baum_welch_example/random_init.pl @@ -0,0 +1,9 @@ +#!/usr/bin/perl -w +while(<>) { + chomp; + my ($a,$b,@d) =split /\s+/; + die "Bad input" if scalar @d > 0; + $r = -rand() * rand() - 0.5; + $r = 0 if $a =~ /^Uni:/; + print "$a $r\n"; +} diff --git a/training/crf/baum_welch_example/tagset.txt b/training/crf/baum_welch_example/tagset.txt new file mode 100644 index 00000000..93a48451 --- /dev/null +++ b/training/crf/baum_welch_example/tagset.txt @@ -0,0 +1 @@ +1 2 3 4 diff --git a/training/crf/baum_welch_example/train.txt b/training/crf/baum_welch_example/train.txt new file mode 100644 index 00000000..e9c3455e --- /dev/null +++ b/training/crf/baum_welch_example/train.txt @@ -0,0 +1,2000 @@ +t h e +t o +o f +i n +a n d +a +s a i d +f o r +o n +t h a t +w i t h +w a s +i s +b y +a t +h e +a s +f r o m +i t +h a s +b e +h i s +h a v e +w i l l +a n +a r e +w e r e +b u t +n o t +w h o +a f t e r +h a d +y e a r +i t s +t w o +t h i s +w h i c h +t h e y +t h e i r +g o v e r n m e n t +b e e n +w e +p e r c e n t +w o u l d +n e w +i +a l s o +u p +m o r e +o n e +p e o p l e +f i r s t +l a s t +a b o u t +c h i n a +p r e s i d e n t +o v e r +m i l l i o n +o r +o u t +w o r l d +w h e n +a l l +o t h e r +m i n i s t e r +t h r e e +t h a n +u n i t e d +t h e r e +a g a i n s t +i n t o +c o u n t r y +s o m e +p o l i c e +n o +t i m e +y e a r s +s t a t e +w e d n e s d a y +t u e s d a y +t h u r s d a y +s t a t e s +m o n d a y +u s +c o u l d +i f +f r i d a y +s i n c e +b i l l i o n +s h e +f o r e i g n +o f f i c i a l s +d a y +i n t e r n a t i o n a l +h e r +b e t w e e n +o n l y +b e f o r e +s o u t h +w h i l e +d u r i n g +n a t i o n a l +t o l d +s e c o n d +g r o u p +f o u r +d o w n +c i t y +p a r t y +t h e m +s e c u r i t y +d o +m a d e +d o l l a r s +p o i n t s +u n d e r +m i l i t a r y +b e c a u s e +w e e k +c o u n t r i e s +c a n +c h i n e s e +o f f +s u n d a y +m o s t +s o +h i m +e c o n o m i c +f o r m e r +i r a q +f i v e +s a t u r d a y +a c c o r d i n g +d i d +n o w +o f f i c i a l +m a y +n e w s +w a r +a n y +w h e r e +t e a m +m e e t i n g +k i l l e d +b a n k +s h o u l d +j u s t +r e p o r t e d +m a n y +n e x t +w h a t +c o m p a n y +i n c l u d i n g +b a c k +m o n t h +r e p o r t +o u r +p r i m e +m a r k e t +s t i l l +b e i n g +c o u r t +t r a d e +h e r e +p e a c e +h i g h +o l d +s e t +t h r o u g h +y o u +i s r a e l +t a l k s +e n d +t a k e +e x p e c t e d +p o l i t i c a l +s i x +s u c h +b o t h +m a k e +h o m e +l o c a l +j a p a n +r u s s i a +s a y i n g +g e n e r a l +t o p +a n o t h e r +e u r o p e a n +n o r t h +h e l d +t h i r d +m a j o r +s t a t e m e n t +w e l l +a m e r i c a n +i s r a e l i +t a i w a n +l e a d e r +c a p i t a l +l o n g +o i l +t h o s e +c a l l e d +p a r t +s p o k e s m a n +w o r k +d e v e l o p m e n t +a d d e d +s a y s +w o n +m e m b e r s +l e f t +c h i e f +g a m e +l i k e +t h e n +h e l p +s a y +p a l e s t i n i a n +v e r y +c u p +p u b l i c +f r a n c e +c e n t r a l +l e a d e r s +w i n +b u s h +m i n i s t r y +m o n t h s +g e t +w a y +d a y s +r e g i o n +s u p p o r t +t r o o p s +a g e n c y +f o r c e s +e a r l i e r +e v e n +n a t i o n s +v i s i t +g a m e s +e u +f i n a l +a m o n g +h o u s e +s e v e r a l +e a r l y +l e d +d l r s +l a t e r +w o m e n +k o n g +h o n g +p r e s s +p o w e r +t o d a y +o p e n +i n d e x +o f f i c e +f o l l o w i n g +a r o u n d +b a s e d +c o n f e r e n c e +b r i t i s h +c o u n c i l +u n i o n +t o o k +c a m e +w e s t +r u n +h o w e v e r +e a s t +l a t e +s e a s o n +g o o d +c l o s e +g e r m a n y +l e a d +p a s t +d e f e n s e +p l a c e +n u m b e r +a r m y +r u s s i a n +l a w +i n d i a +m e n +f i n a n c i a l +e c o n o m y +l e a s t +s e c r e t a r y +s a m e +y o r k +f o u n d +g o i n g +r i g h t +g o +m y +o p p o s i t i o n +f o r c e +a g r e e m e n t +e l e c t i o n +h o w +b u s i n e s s +f r e n c h +a u t h o r i t i e s +p l a y +m u c h +r i g h t s +t i m e s +c o m m i t t e e +r o u n d +p r o v i n c e +k o r e a +h a l f +a t t a c k +p r i c e s +s t o c k +h i t +p l a n +a r e a +c o o p e r a t i o n +s e v e n +n e a r +e x c h a n g e +u s e d +n u c l e a r +p a k i s t a n +b e i j i n g +a n n o u n c e d +a i r +a f r i c a +c e n t e r +a g o +t h e s e +d e c i s i o n +a t t a c k s +w i t h o u t +m a t c h +m a r c h +n a t i o n +h e a d +t o t a l +c o m p a n i e s +m a n +d e a l +w a s h i n g t o n +r e c e n t +c a s e +f i r e +n i g h t +a u s t r a l i a +a f r i c a n +u n t i l +i r a n +e l e c t i o n s +s o u t h e r n +l e a g u e +p u t +e a c h +m e m b e r +c h i l d r e n +h e a l t h +p c +p a r l i a m e n t +l o s t +t h i n k +d e a t h +m u s t +e i g h t +w o r k e r s +u s e +b r i t a i n +w a n t +s y s t e m +r e c o r d +d e p a r t m e n t +p r o g r a m +e f f o r t s +g r o w t h +r e s u l t s +i r a q i +i s s u e +b e s t +w h e t h e r +h u m a n +n o r t h e r n +c o n t r o l +f a r +f u r t h e r +a l r e a d y +s h a r e s +r e l a t i o n s +m e e t +s o l d i e r s +s e e +f r e e +c o m e +j a p a n e s e +m o n e y +d o l l a r +r e p o r t s +d i r e c t o r +s h a r e +g i v e +j u n e +c o m m i s s i o n +l a r g e s t +i n d u s t r y +c o n t i n u e +s t a r t +c a m p a i g n +l e a d i n g +q u a r t e r +i n f o r m a t i o n +v i c t o r y +r o s e +o t h e r s +a n t i +r e t u r n +f a m i l y +i s s u e s +s h o t +p o l i c y +e u r o p e +m e d i a +p l a n s +b o r d e r +n e e d +p e r +a r e a s +j u l y +v i o l e n c e +d e s p i t e +d o e s +s t r o n g +c h a i r m a n +s e r v i c e +v o t e +a s k e d +f o o d +f e d e r a l +w e n t +t a k e n +d u e +m o v e +o w n +b e g a n +i t a l y +g r o u p s +p o s s i b l e +l i f e +c l e a r +r a t e +f e l l +c r i s i s +r e b e l s +p o i n t +d e m o c r a t i c +w a t e r +a h e a d +i n v e s t m e n t +i n c r e a s e +s h o w +p l a y e r s +r e l e a s e d +g e r m a n +t o w n +n e a r l y +p r o c e s s +m i l e s +f e w +d i e d +a d m i n i s t r a t i o n +a p r i l +w e e k s +l e v e l +k e y +a s i a +s a l e s +w e a p o n s +c l o s e d +b e h i n d +m i n u t e s +a g r e e d +p r e s i d e n t i a l +g r e a t +m a k i n g +t o o +p r i c e +j o h n +g o t +f o u r t h +a g a i n +k i l o m e t e r s +s i t u a t i o n +m a i n +w h i t e +r e p o r t e r s +h o u r s +s e n i o r +a s i a n +r e p u b l i c +a w a y +g l o b a l +f i g h t i n g +s e r i e s +b e t t e r +n e w s p a p e r +m e +c l i n t o n +a r r e s t e d +h i g h e r +k n o w +f u t u r e +s c o r e d +g o l d +n a t o +m o r n i n g +n e v e r +b e a t +w i t h i n +r n +c u r r e n t +p e r i o d +b e c o m e +w e s t e r n +i m p o r t a n t +l o n d o n +a u s t r a l i a n +s p a i n +e n e r g y +a i d +a c c u s e d +o l y m p i c +n i n e +a c r o s s +f a c e +o r g a n i z a t i o n +g o a l +t a k i n g +i n j u r e d +s p e c i a l +r a c e +d a i l y +s u m m i t +s i d e +l i n e +o r d e r +u n i v e r s i t y +a f g h a n i s t a n +p l a y e d +b i g +c a r +t r y i n g +e n g l a n d +q u o t e d +d e +a l o n g +i s l a m i c +o u t s i d e +t r a d i n g +e d s +c u t +a c t i o n +p r o b l e m s +v i c e +w o r k i n g +y e n +b u i l d i n g +s i g n e d +k n o w n +c h a n g e +c h a r g e s +s m a l l +l o w e r +a l t h o u g h +s e n t +c o n g r e s s +h o s p i t a l +h o l d +m i g h t +u n +e v e r y +g i v e n +d e p u t y +i n t e r e s t +i s l a n d +s c h o o l +d r u g +k i l l i n g +r u l i n g +t o u r +o p e n i n g +t e r m +f u l l +c l r +l i t t l e +m a r k e t s +c o a c h +j a n u a r y +s c h e d u l e d +k e e p +t u r k e y +p r e v i o u s +e x e c u t i v e +g a s +m e t +j o i n t +t r i a l +b o a r d +p r o d u c t i o n +i n d o n e s i a +s e r v i c e s +l i k e l y +t h o u s a n d s +i n d i a n +p o s t +a r a b +c e n t s +h o p e +s i n g a p o r e +p a l e s t i n i a n s +p a r t i e s +g a v e +b i l l +d e a d +r o l e +s e p t e m b e r +t e l e v i s i o n +c o m m u n i t y +r e g i o n a l +a d d i n g +a m e r i c a +o n c e +y u a n +t e s t +s t o c k s +w h o s e +p a y +p r i v a t e +l a t e s t +i n v e s t o r s +f r o n t +c a n a d a +r e l e a s e +r e c e i v e d +m e a n w h i l e +l e s s +t h a i l a n d +l a n d +c h a m p i o n +r e a c h e d +u r g e d +d e c e m b e r +a s s o c i a t i o n +f i g h t +s i d e s +s t a r t e d +l a r g e +y e t +m i d d l e +c a l l +p r e s s u r e +e n d e d +s o c i a l +p r o j e c t +l o w +h a r d +c l u b +p r e m i e r +t e c h n o l o g y +f a i l e d +t o u r n a m e n t +r e a l +p r o v i d e +g a z a +m i n u t e +a f f a i r s +m i n i s t e r s +p r o d u c t s +r e s e a r c h +s e e n +g e o r g e +e v e n t +s t o p +i n v e s t i g a t i o n +a i r p o r t +m e x i c o +t i t l e +t o k y o +e a s t e r n +b i g g e s t +y o u n g +d e m a n d +t h o u g h +a r m e d +s a n +o p e n e d +m e a s u r e s +n o v e m b e r +a v e r a g e +m a r k +o c t o b e r +k o r e a n +r a d i o +b o d y +s e c t o r +c a b i n e t +g m t +a s s o c i a t e d +a p +c i v i l +t e r r o r i s m +s h o w e d +p r i s o n +s i t e +p r o b l e m +s e s s i o n +b r a z i l +m u s l i m +c o a l i t i o n +b a g h d a d +b i d +s t r e e t +c o m i n g +b e l i e v e +m a l a y s i a +s t u d e n t s +d e c i d e d +f i e l d +r e d +n e g o t i a t i o n s +w i n n i n g +o p e r a t i o n +c r o s s +s o o n +p l a n n e d +a b l e +t i e s +t a x +j u s t i c e +d o m e s t i c +d a v i d +i n c l u d e +n a m e +b o m b +t r a i n i n g +j u d g e +v i c t i m s +m e d i c a l +c o n d i t i o n +f i n d +r e m a i n +i s s u e d +f i n a n c e +l o t +l a b o r +b t +e n o u g h +i m m e d i a t e l y +s h o r t +l o s s +a n n u a l +m o v e d +r e b e l +s t r i k e +r o a d +r e c e n t l y +i t a l i a n +c o n s t r u c t i o n +t r y +a u g u s t +e x p r e s s e d +m i l i t a n t s +t o g e t h e r +w a n t e d +r a t e s +f u n d +f o r w a r d +m i s s i o n +d i s c u s s +r e s u l t +c a l l s +k o s o v o +o p e r a t i o n s +c a s e s +z e a l a n d +s o u r c e s +i n c r e a s e d +l e g a l +b a n k s +i n v o l v e d +o f f i c e r s +l e a v e +m e t e r s +w a r n e d +h a v i n g +r e a c h +b r i n g +h i s t o r y +d i s t r i c t +j o b +a l l o w e d +a r r i v e d +t o w a r d +c l a i m e d +e g y p t +t e a m s +a l l o w +a l m o s t +f e b r u a r y +s e r i o u s +p o o r +c o n t i n u e d +s t e p +i n t e r v i e w +e d u c a t i o n +n o n +r e a l l y +s t a r +l e e +r e s i d e n t s +b a n +s o c c e r +n e e d e d +p a r i s +i n d u s t r i a l +p l a y e r +m o s c o w +s t a t i o n +o f f e r +h u n d r e d s +t a l i b a n +w o m a n +m a n a g e m e n t +l e b a n o n +n o t e d +c h e n +p o s i t i o n +f i n i s h e d +c o s t +e x p e r t s +e v e r +m o v e m e n t +t e r r o r i s t +p l a n e +b l a c k +d i f f e r e n t +b e l i e v e d +p l a y i n g +c a u s e d +h o p e s +c o n d i t i o n s +b r o u g h t +f o r c e d +l a u n c h e d +w e e k e n d +m i c h a e l +s e a +r i s e +d e t a i l s +s p o r t s +e t h n i c +s t a f f +c h a n c e +g o a l s +b u d g e t +h a n d +b a s e +s e c o n d s +s r i +s p e a k i n g +o f f i c e r +m a j o r i t y +w a n t s +c h a r g e d +s h a n g h a i +v i e t n a m +x i n h u a +c o m m e n t +d r o p p e d +t u r n e d +p r o t e s t +r e f o r m +s u s p e c t e d +a m i d +t r i e d +c i t i e s +g r o u n d +t u r k i s h +s t a g e +e f f o r t +s +c o m m u n i s t +a n a l y s t s +h a m a s +p r o j e c t s +c o n t r a c t +i n d e p e n d e n c e +l o o k i n g +a m +s i g n +f o l l o w e d +r e m a i n s +c o m p a r e d +u s i n g +h e a v y +a f t e r n o o n +s t r a i g h t +l o o k +f a l l +r e a d y +e u r o +c h a r g e +w o u n d e d +p r o g r e s s +p a c i f i c +d e n i e d +h o u r +c a r e e r +c o n f i r m e d +t h a i +r u l e +c o u r s e +w i f e +e x p o r t s +b e c a m e +a m e r i c a n s +e m e r g e n c y +a r a f a t +r e f u s e d +l i s t +a l l e g e d +c h a m p i o n s h i p +p o p u l a t i o n +n e e d s +c o m p e t i t i o n +o r d e r e d +s a f e t y +a u t h o r i t y +i l l e g a l +t v +d o n e +e v i d e n c e +s t a y +f i f t h +s e e k i n g +s t u d y +l i v e +r u n s +c o a s t +s a u d i +h e l p e d +a c t i v i t i e s +m a n a g e r +w o r t h +k i n g +g r o w i n g +r u n n i n g +f i r e d +i n c l u d e d +p a u l +w a l l +r e t u r n e d +c o n f l i c t +m y a n m a r +d e m o c r a c y +p r o +f o r m +a l w a y s +a m b a s s a d o r +m a t c h e s +t h i n g s +m a i n l a n d +s a w +d i s e a s e +r e l a t e d +f u n d s +i n d e p e n d e n t +t o n s +a p p r o v e d +e m b a s s y +c u r r e n c y +b r e a k +s e n a t e +c o n c e r n s +f i g u r e s +j o i n +r e s o l u t i o n +o f t e n +c o n f i d e n c e +e s p e c i a l l y +w i n n e r +c a r r i e d +i m p r o v e +s w e d e n +z i m b a b w e +t h r e a t +c u r r e n t l y +s i n g l e +h i m s e l f +l i v i n g +r e f u g e e s +a i m e d +c o u n t y +c a n n o t +a r m s +b u i l d +g e t t i n g +a p p e a r e d +d i f f i c u l t +s p a n i s h +r i v e r +m i s s i n g +e s t i m a t e d +s o m e t h i n g +p r o p o s e d +c e r e m o n y +i n s t e a d +b r o k e +c h u r c h +o l y m p i c s +s p a c e +p r o f i t +v i l l a g e +l i g h t +p e r f o r m a n c e +d e l e g a t i o n +t r i p +o v e r a l l +p a r t s +a c t +c o r r u p t i o n +d i v i s i o n +s i m i l a r +p o s i t i v e +c a m p +g r a n d +p o r t +s u p p o r t e r s +r e p u b l i c a n +b e g i n +j o n e s +p a r k +b i l a t e r a l +c l o u d y +d i p l o m a t i c +p r e s e n t +l o s +a r g e n t i n a +t r a v e l +s p e e c h +a t t e n t i o n +n e t +j o b s +a r r e s t +p r o s e c u t o r s +i n f l a t i o n +n a m e d +j o r d a n +s o n +g o v e r n m e n t s +r u l e s +p r o t e c t i o n +k e n y a +h o m e s +l i v e s +s e r b +s a n c t i o n s +a t t e m p t +e x p o r t +m e a n s +n i g e r i a +r e m a i n e d +t u r n +c r i m e s +c o n c e r n +e n v i r o n m e n t +p l a n t +l e t t e r +v a l u e +r e s p o n s e +a s s e m b l y +p r o p o s a l +h o l d i n g +b o m b i n g +e n s u r e +a f g h a n +r e s o u r c e s +f a m i l i e s +r e s t +i n s i d e +t h r o u g h o u t +m a t t e r +c a u s e +l a w m a k e r s +i i +f u e l +c a l i f o r n i a +e g y p t i a n +o w n e d +s u i c i d e +c z e c h +c a r e +a t t o r n e y +c l a i m s +v o t e r s +n e t w o r k +b a l l +p h i l i p p i n e +f o o t b a l l +s p o k e s w o m a n +i n c i d e n t +p r e v e n t +w h y +d e v e l o p i n g +c i v i l i a n s +e n g l i s h +o b a m a +i n t e r n e t +r i c e +s a d d a m +y o u r +u p d a t e s +l e t +d o i n g +a i r c r a f t +f l i g h t +a n g e l e s +i n t e l l i g e n c e +p h i l i p p i n e s +f a t h e r +c r e d i t +a l l i a n c e +t e r m s +r a i s e d +i r a n i a n +c h a n g e s +s y r i a +v a r i o u s +i n d o n e s i a n +l i +i r e l a n d +l e a v i n g +d e c l i n e d +c o m m o n +i n j u r y +t r e a t m e n t +a v a i l a b l e +c h a m p i o n s +e l e c t e d +s u m m e r +d a t a +o v e r s e a s +p a i d +c e n t u r y +n o t h i n g +f i r m +r e l i g i o u s +s w i t z e r l a n d +o f f e r e d +c h a m p i o n s h i p s +t h o u g h t +c a n d i d a t e +c o n s i d e r e d +r i s k +c r i m e +g o v e r n o r +f i l m +r a l l y +f l o r i d a +t e r r o r +d o u b l e +e q u i p m e n t +j e r u s a l e m +c a r r y i n g +p e r s o n +f e e l +t e r r i t o r y +a l +c o m m e r c i a l +u k r a i n e +b o d i e s +p r o t e s t s +n e t h e r l a n d s +f i n i s h +a c c e s s +t a r g e t +a u s t r i a +s o u r c e +r e p r e s e n t a t i v e s +s p e n t +j e w i s h +p o t e n t i a l +r i s i n g +t r e a t y +c a n a d i a n +a g e +c a +s p e n d i n g +n e c e s s a r y +r a i n +z o n e +c a r s +p r o m o t e +n a t u r a l +d a m a g e +f o c u s +w e a t h e r +p o l i c i e s +p r o t e c t +a i d s +c o +g i v i n g +b c +b a c k e d +l a n k a +a p p e a l +r e j e c t e d +f a n s +b a d +s o u t h e a s t +r i v a l +p l a n n i n g +b o s n i a +c o m e s +b u y +s o v i e t +h o t e l +d u t c h +q u e s t i o n +t a i p e i +b o o s t +c o s t s +i n s t i t u t e +s o c i e t y +s h o o t i n g +t h e m s e l v e s +e v e n t s +k i n d +p a p e r +w o r k e d +c o n s t i t u t i o n +u r g e n t +s e t t l e m e n t +e a r n i n g s +j o s e +m o t h e r +a c c i d e n t +f a c t +d r o p +r a n g e +h a n d s +s e e k +h u g e +l a w y e r +s t a r t i n g +h e a r t +c o m m a n d e r +t o u r i s m +p a s s e n g e r s +s u s p e c t s +h i g h e s t +p o p u l a r +s t a b i l i t y +s u p r e m e +b u s +r o b e r t +b a t t l e +p r o g r a m s +c u b a +w i n s +d r u g s +s u r v e y +h o s t +m u r d e r +d a t e +g u l f +w i l l i a m s +s e n d +s u f f e r e d +p e n a l t y +k e p t +s t a d i u m +c i t i z e n s +f i g u r e +h e a d q u a r t e r s +g u a r d +p u b l i s h e d +s t a n d +t e n n i s +c r e a t e +b e g i n n i n g +e v e n i n g +p h o n e +f o o t +r u l e d +c a s h +s o l d +c h i c a g o +p o l a n d +d e m o c r a t s +r e f o r m s +b o s n i a n +s u r e +c h i l d +m a y o r +a t t e n d +l e a d e r s h i p +e m p l o y e e s +t e l e p h o n e +l o s s e s +b o r n +a s s i s t a n c e +t h i n g +t r a i n +s u p p l y +e i t h e r +b u i l t +l a u n c h +c r u d e +m o v i n g +g r e e c e +t r a c k +r a i s e +d r i v e +r e s p o n s i b i l i t y +f e d e r a t i o n +c o l o m b i a +g r e e n +c o n c e r n e d +c a n d i d a t e s +n e w s p a p e r s +r e v i e w +i n t e r i o r +d e b t +w h o l e +t e x a s +m o s t l y +r e l i e f +f a r m e r s +g o o d s +p a k i s t a n i +d e g r e e s +s e l l +d e t a i n e d +s w i s s +c r i m i n a l +d e c a d e s +m i s s i l e +a b o v e +d r a w +p a s s e d +e x p l o s i o n +m a k e s +l a w s +b a n g l a d e s h +t a l k +m a d r i d +m a s s +c o n v i c t e d +i t e m s +m e d a l +s u c c e s s +s e a t s +q u i c k l y +c a l l i n g +k i m +t r a f f i c +d i r e c t +o r g a n i z a t i o n s +l e v e l s +s e r v e +a d d r e s s +s t r e s s e d +s t a n d i n g +w a n g +d e c l a r e d +j a m e s +c a p t a i n +t h r e a t e n e d +p r o m i s e d +s u d a n +v a n +p a s s +e n v i r o n m e n t a l +r a t h e r +w o r s t +p o u n d s +b l u e +s i x t h +m e t e r +i n c l u d e s +m u s i c +r e d u c e +t a k e s +v o t e s +r e s c u e +c o m p l e t e d +s e a r c h +i n n i n g s +v e h i c l e s +c l a i m +t r a n s p o r t +a v o i d +i n c o m e +p o l l +a f f e c t e d +g e o r g i a +g a i n e d +w o +r e +v i s i t i n g +r e s p o n s i b l e +e f f e c t +p o l l s +h e a r i n g +l o s i n g +e s t a b l i s h e d +f a i r +g i a n t +c h a l l e n g e +f e e t +p r o p e r t y +t e s t s +l e g +a g r i c u l t u r e +l o n g e r +d e a t h s +s q u a r e +p a r t i c u l a r l y +d i s p u t e +b +e n t e r p r i s e s +v o l u m e +c a r r y +m i d +s e p a r a t e +i d e n t i f i e d +i t s e l f +h e a d e d +a n o n y m i t y +p a r l i a m e n t a r y +c r a s h +r e m a i n i n g +j o u r n a l i s t s +i n c r e a s i n g +s t a t i s t i c s +d e s c r i b e d +b u r e a u +i n j u r i e s +p r o v i d e d +j o i n e d +i m m e d i a t e +d e b a t e +i m p a c t +m e s s a g e +m e e t i n g s +r e q u e s t +s c h o o l s +o c c u r r e d +r e m a r k s +c o m m i t t e d +p r o t e s t e r s +t o u g h +s p o k e +s t r i p +f a c e s +c r o w d +s h o w s +w a r n i n g +s t o r y +q u a l i t y +p e t e r +f r e e d o m +d e v e l o p +m a r t i n +p e r s o n a l +s e r b i a +a n y t h i n g +b l a m e d +i n t e r e s t s +n e i g h b o r i n g +d o c t o r s +f l i g h t s +s h i p +r e g i m e +b l a i r +u n i t +a g e n c i e s +a f p +s u g g e s t e d +l a c k +s e l l i n g +a n n a n +y u g o s l a v i a +l a +c o n s u m e r +s u s p e n d e d +s t o p p e d +c o m m e n t s +c o m p u t e r +c o n s i d e r +a i r l i n e s +l e b a n e s e +p r e p a r e d +d i a l o g u e +e x p e c t +t w i c e +p u t i n +a l l e g a t i o n s +b r o w n +a c c e p t +a p p r o v a l +w i d e +n e a r b y +s y s t e m s +v i e w +p u s h +p r o b a b l y +e v e r y t h i n g +d r a f t +t r a d i t i o n a l +s t a t u s +s t r u c k +s e i z e d +p a r t l y +s t a n d a r d +h u s s e i n +p o v e r t y +d o z e n s +r e g i o n s +c r i c k e t +l o a n s +e +b o o k +b a s i s +a n n o u n c e m e n t +r u r a l +s e r b s +a d d i t i o n +g r e e k +c o m p l e t e +r o o m +g r e a t e r +a l l e g e d l y +f i n a l s +f a c i n g +l i m i t e d +c u t s +r i c h a r d +b u s i n e s s e s +l i n k e d +p e a c e f u l +c r e w +t o u r i s t s +m a i n l y +p r i s o n e r s +p o w e r f u l +c r o a t i a +f i l e d +k u w a i t +f o r u m +r e s e r v e +m i l a n +b l a s t +a n n i v e r s a r y +a t t e n d e d +e n d i n g +d e v e l o p e d +c e r t a i n +b e l o w +f e l t +p r o v i n c i a l +c y p r u s +c r i t i c i z e d +o p p o r t u n i t y +s m i t h +p o l i t i c s +s e l f +h u m a n i t a r i a n +r e a s o n +l a w y e r s +r e v e n u e +d o c u m e n t s +w r o t e +q u e s t i o n s +n o r w a y +d o w +p a n e l +f e a r +s e n t e n c e d +b a n n e d +c i v i l i a n +c u l t u r a l +p e r s o n n e l +b e l g i u m +a b u +c a p a c i t y +a m o u n t +s e c u r i t i e s +b l o o d +s i g n i f i c a n t +e x p e r i e n c e +a s e a n +h o u s i n g +j o h n s o n +p h o t o s +r o y a l +i m p o r t s +a d d i t i o n a l +y e l t s i n +c d y +h e a r d +t h o m a s +b a n k i n g +l e a d s +v i s i t e d +f e a r s +u g a n d a +d r i v e r +c o n t r o l l e d +d e m a n d s +i n s t i t u t i o n s +a l i +c h r i s t i a n +s t o r m +f o r e c a s t +g r a f +f i g h t e r s +s t r e e t s +r e s p e c t +s p o t +w e b +m i s s e d +s c i e n c e +h e a d s +h i t s +m a s s i v e +c u l t u r e +c o u p l e +v e n e z u e l a +r e p o r t e d l y +i n s u r a n c e +s p r e a d +s o l u t i o n +p l a c e d +s e r v e d +f a c i l i t i e s +s t r a t e g y +t e c h n i c a l +s t e p s +d e e p +h o p e d +d e c i d e +s a l e +j a i l +d i s c u s s e d +s a v e +n e p a l +a r a b i a +e n v o y +a t t a c k e d +w a y s +r e c e i v e +h a p p y +h a l l +g u i l t y +p r a c t i c e +l o v e +e u r o s +o p e r a t i n g +c h a n g e d +b o s t o n +d e c a d e +d e f i c i t +p r o d u c t +l i n e s +p a t i e n t s +f r i e n d s +s y d n e y +a c c o r d +t i e d +s p e e d +w o r d s +t i e +s c o r e +c o n d u c t e d +c r i t i c i s m +m u s l i m s +b r o t h e r +c l a s s +r o m a n i a +h e l p i n g +f a s t +h a p p e n e d +d e f e n d i n g +n a v y +w i t n e s s e s +f u l l y +s u s p e c t +i s l a n d s +m a i n t a i n +p r e s e n c e +j a k a r t a +p a c k a g e +y a r d s +g a i n +a c c o u n t +s q u a d +s h a r o n +w i n g +a c t i o n s +a t h e n s +s t r a t e g i c +s t r e n g t h e n +f r i e n d l y +d e s t r o y e d +a p p a r e n t l y +c o n s e r v a t i v e +g a i n s +f a i l u r e +f u t u r e s +s h o t s +r e l a t i o n s h i p +c o m m i s s i o n e r +m a l a y s i a n +r e q u i r e d +a t l a n t a +a g r e e +d e f e a t +s t r i k e r +a d v a n c e d +b r a z i l i a n +a s s e t s +h o u s e s +s u p p l i e s +s a f e +m i l l i o n s +s o u g h t +f r e s h +v i d e o +p r o s e c u t o r +p u l l e d +v e h i c l e +t o l l +p a r e n t s +c e a s e +a c t i v i s t s +o r g a n i z e d +e n t e r e d +s h i i t e +l a n g u a g e +a b b a s +b i n +p r e v i o u s l y +c l o s i n g +w o r k s +t e r r o r i s t s +t o n y +c o v e r +f o l l o w +l e g i s l a t i v e +r i c h +c l a s h e s +i m p o s e d +r a n +m c c a i n +s u c c e s s f u l +s e v e n t h +s c o r i n g +c a u g h t +a p p o i n t e d +a l l i e s +a d m i t t e d +w o r l d w i d e +o r d e r s +d e m a n d e d +c r e a t e d +r a n k e d +m i l i t a n t +i n v e s t i g a t o r s +s h o w i n g +p o s s i b i l i t y +s e a t +d a u g h t e r +s i t e s +s h o r t l y +c o m m e r c e +n e t a n y a h u +a d v a n c e +a i r l i n e +f i r m s +a b r o a d +f o u n d a t i o n +c o m m i t m e n t +p l e d g e d +k i l l +r e p r e s e n t a t i v e +n o r t h w e s t +s c e n e +b e a t i n g +i m p r o v e d +r e s u m e +w h o m +s l i g h t l y +v o t i n g +b o m b i n g s +s e r i o u s l y +s e t t i n g +c a r l o s +e f f e c t i v e +h k +r e g u l a r +j i a n g +p r i n c e +d e c l i n e +b a y +n o r t h e a s t +s o l d i e r +r e a c h i n g +a g r e e m e n t s +m i k e +h u r t +c r i t i c a l +i d e a +m i l o s e v i c +f i s c a l +t a r g e t s +a g r i c u l t u r a l +m u s h a r r a f +d e s i g n e d +o v e r n i g h t +b o y +d o z e n +p r o d u c e +c a l m +s t a n d a r d s +l e g i s l a t i o n +s e n t e n c e +w i t h d r a w a l +s e e d e d +c o m p o s i t e +t r a d e d +w i n t e r +d a v i s +t r u s t +c l i m a t e +i n d u s t r i e s +p r o f i t s +v o t e d +c a m b o d i a +s y r i a n +s i g n s +l o a n +s t e e l +e l e c t r i c i t y +t e h r a n +c i t i n g +h u s b a n d +b i t +c o m b a t +h a n d e d +f e s t i v a l +i m f +p r e s i d e n c y +c a p t u r e d +s t u d e n t +f i n e +s t a t i o n s +s i l v e r +c h a v e z +i n t e r +m o m e n t +t a b l e +c o u p +p o p e +p r o v i n c e s +a h m e d +b u i l d i n g s +o u t p u t +l i b e r a t i o n +m o n e t a r y +c l o s e r +c o l l e g e +f l u +a d v a n t a g e +a s s i s t a n t +g o n e +s e c r e t +x +c a t h o l i c +n a m e s +l i s t e d +f i n a l l y +c a n c e r +p r o d u c e d +m e a s u r e +f l e d +l a r g e l y +d e f e a t e d +c o n g o +b a s i c +j e a n +l o s e +p r i z e +b a n g k o k +a s k +f r a n c i s c o +r e g i s t e r e d +d i s a s t e r +g o l f +i n d i v i d u a l +c o n t i n u e s +w t o +i n i t i a l +a n y o n e +q u a k e +f a c e d +s c i e n t i s t s +m o b i l e +p o s i t i o n s +f i e l d s +r e c o v e r y +m u s e u m +n u m b e r s +d e n m a r k +m a n i l a +h o l d s +c e n t +e x +e s t a b l i s h +w i d e l y +o f f i c e s +i n s i s t e d +u n i t s +k a s h m i r +r e f e r e n d u m +l o c a t e d +u p o n +a l l o w i n g +s c a l e +o p p o s e d +w a t c h +i n d i c a t e d +p a r t n e r +e a r t h q u a k e +s c a n d a l +e v e r y o n e +a p p r o a c h +t r u c k +i m p o r t a n c e +t h r e a t s +p o r t u g a l +s e x +r e c o r d s +s u p e r +s t o o d +c o n t a c t +m a t e r i a l s +v i o l e n t +p l a c e s +a n a l y s t +a d d s +a l o n e +g o e s +m o v i e +e x p e c t s +a r t +s e o u l +m e x i c a n +y e s t e r d a y +p l a n e s +n i n t h +o n l i n e +h e l i c o p t e r +i m m i g r a t i o n +p a r t n e r s +i n f r a s t r u c t u r e +b o a t +v i s i t s +n o r m a l +s t a k e +g u e r r i l l a s +m a c a o +w i l l i n g +s u n +a w a r d +t e l l +s o u t h w e s t +s p o r t +e n t e r +r e s o l v e +c h a n c e s +m i a m i +e l +e n t i r e diff --git a/training/crf/cllh_observer.cc b/training/crf/cllh_observer.cc new file mode 100644 index 00000000..4ec2fa65 --- /dev/null +++ b/training/crf/cllh_observer.cc @@ -0,0 +1,52 @@ +#include "cllh_observer.h" + +#include <cmath> +#include <cassert> + +#include "inside_outside.h" +#include "hg.h" +#include "sentence_metadata.h" + +using namespace std; + +static const double kMINUS_EPSILON = -1e-6; + +ConditionalLikelihoodObserver::~ConditionalLikelihoodObserver() {} + +void ConditionalLikelihoodObserver::NotifyDecodingStart(const SentenceMetadata&) { + cur_obj = 0; + state = 1; +} + +void ConditionalLikelihoodObserver::NotifyTranslationForest(const SentenceMetadata&, Hypergraph* hg) { + assert(state == 1); + state = 2; + SparseVector<prob_t> cur_model_exp; + const prob_t z = InsideOutside<prob_t, + EdgeProb, + SparseVector<prob_t>, + EdgeFeaturesAndProbWeightFunction>(*hg, &cur_model_exp); + cur_obj = log(z); +} + +void ConditionalLikelihoodObserver::NotifyAlignmentForest(const SentenceMetadata& smeta, Hypergraph* hg) { + assert(state == 2); + state = 3; + SparseVector<prob_t> ref_exp; + const prob_t ref_z = InsideOutside<prob_t, + EdgeProb, + SparseVector<prob_t>, + EdgeFeaturesAndProbWeightFunction>(*hg, &ref_exp); + + double log_ref_z = log(ref_z); + + // rounding errors means that <0 is too strict + if ((cur_obj - log_ref_z) < kMINUS_EPSILON) { + cerr << "DIFF. ERR! log_model_z < log_ref_z: " << cur_obj << " " << log_ref_z << endl; + exit(1); + } + assert(!std::isnan(log_ref_z)); + acc_obj += (cur_obj - log_ref_z); + trg_words += smeta.GetReference().size(); +} + diff --git a/training/crf/cllh_observer.h b/training/crf/cllh_observer.h new file mode 100644 index 00000000..0de47331 --- /dev/null +++ b/training/crf/cllh_observer.h @@ -0,0 +1,26 @@ +#ifndef _CLLH_OBSERVER_H_ +#define _CLLH_OBSERVER_H_ + +#include "decoder.h" + +struct ConditionalLikelihoodObserver : public DecoderObserver { + + ConditionalLikelihoodObserver() : trg_words(), acc_obj(), cur_obj() {} + ~ConditionalLikelihoodObserver(); + + void Reset() { + acc_obj = 0; + trg_words = 0; + } + + virtual void NotifyDecodingStart(const SentenceMetadata&); + virtual void NotifyTranslationForest(const SentenceMetadata&, Hypergraph* hg); + virtual void NotifyAlignmentForest(const SentenceMetadata& smeta, Hypergraph* hg); + + unsigned trg_words; + double acc_obj; + double cur_obj; + int state; +}; + +#endif diff --git a/training/crf/mpi_batch_optimize.cc b/training/crf/mpi_batch_optimize.cc new file mode 100644 index 00000000..2eff07e4 --- /dev/null +++ b/training/crf/mpi_batch_optimize.cc @@ -0,0 +1,372 @@ +#include <sstream> +#include <iostream> +#include <vector> +#include <cassert> +#include <cmath> + +#include "config.h" +#ifdef HAVE_MPI +#include <boost/mpi/timer.hpp> +#include <boost/mpi.hpp> +namespace mpi = boost::mpi; +#endif + +#include <boost/shared_ptr.hpp> +#include <boost/program_options.hpp> +#include <boost/program_options/variables_map.hpp> + +#include "sentence_metadata.h" +#include "cllh_observer.h" +#include "verbose.h" +#include "hg.h" +#include "prob.h" +#include "inside_outside.h" +#include "ff_register.h" +#include "decoder.h" +#include "filelib.h" +#include "stringlib.h" +#include "optimize.h" +#include "fdict.h" +#include "weights.h" +#include "sparse_vector.h" + +using namespace std; +namespace po = boost::program_options; + +bool InitCommandLine(int argc, char** argv, po::variables_map* conf) { + po::options_description opts("Configuration options"); + opts.add_options() + ("input_weights,w",po::value<string>(),"Input feature weights file") + ("training_data,t",po::value<string>(),"Training data") + ("test_data,T",po::value<string>(),"(optional) test data") + ("decoder_config,c",po::value<string>(),"Decoder configuration file") + ("output_weights,o",po::value<string>()->default_value("-"),"Output feature weights file") + ("optimization_method,m", po::value<string>()->default_value("lbfgs"), "Optimization method (sgd, lbfgs, rprop)") + ("correction_buffers,M", po::value<int>()->default_value(10), "Number of gradients for LBFGS to maintain in memory") + ("gaussian_prior,p","Use a Gaussian prior on the weights") + ("sigma_squared", po::value<double>()->default_value(1.0), "Sigma squared term for spherical Gaussian prior") + ("means,u", po::value<string>(), "(optional) file containing the means for Gaussian prior"); + po::options_description clo("Command line options"); + clo.add_options() + ("config", po::value<string>(), "Configuration file") + ("help,h", "Print this help message and exit"); + po::options_description dconfig_options, dcmdline_options; + dconfig_options.add(opts); + dcmdline_options.add(opts).add(clo); + + po::store(parse_command_line(argc, argv, dcmdline_options), *conf); + if (conf->count("config")) { + ifstream config((*conf)["config"].as<string>().c_str()); + po::store(po::parse_config_file(config, dconfig_options), *conf); + } + po::notify(*conf); + + if (conf->count("help") || !conf->count("input_weights") || !(conf->count("training_data")) || !conf->count("decoder_config")) { + cerr << dcmdline_options << endl; + return false; + } + return true; +} + +void ReadTrainingCorpus(const string& fname, int rank, int size, vector<string>* c) { + ReadFile rf(fname); + istream& in = *rf.stream(); + string line; + int lc = 0; + while(in) { + getline(in, line); + if (!in) break; + if (lc % size == rank) c->push_back(line); + ++lc; + } +} + +static const double kMINUS_EPSILON = -1e-6; + +struct TrainingObserver : public DecoderObserver { + void Reset() { + acc_grad.clear(); + acc_obj = 0; + total_complete = 0; + trg_words = 0; + } + + void SetLocalGradientAndObjective(vector<double>* g, double* o) const { + *o = acc_obj; + for (SparseVector<prob_t>::const_iterator it = acc_grad.begin(); it != acc_grad.end(); ++it) + (*g)[it->first] = it->second.as_float(); + } + + virtual void NotifyDecodingStart(const SentenceMetadata& smeta) { + cur_model_exp.clear(); + cur_obj = 0; + state = 1; + } + + // compute model expectations, denominator of objective + virtual void NotifyTranslationForest(const SentenceMetadata& smeta, Hypergraph* hg) { + assert(state == 1); + state = 2; + const prob_t z = InsideOutside<prob_t, + EdgeProb, + SparseVector<prob_t>, + EdgeFeaturesAndProbWeightFunction>(*hg, &cur_model_exp); + cur_obj = log(z); + cur_model_exp /= z; + } + + // compute "empirical" expectations, numerator of objective + virtual void NotifyAlignmentForest(const SentenceMetadata& smeta, Hypergraph* hg) { + assert(state == 2); + state = 3; + SparseVector<prob_t> ref_exp; + const prob_t ref_z = InsideOutside<prob_t, + EdgeProb, + SparseVector<prob_t>, + EdgeFeaturesAndProbWeightFunction>(*hg, &ref_exp); + ref_exp /= ref_z; + + double log_ref_z; +#if 0 + if (crf_uniform_empirical) { + log_ref_z = ref_exp.dot(feature_weights); + } else { + log_ref_z = log(ref_z); + } +#else + log_ref_z = log(ref_z); +#endif + + // rounding errors means that <0 is too strict + if ((cur_obj - log_ref_z) < kMINUS_EPSILON) { + cerr << "DIFF. ERR! log_model_z < log_ref_z: " << cur_obj << " " << log_ref_z << endl; + exit(1); + } + assert(!std::isnan(log_ref_z)); + ref_exp -= cur_model_exp; + acc_grad -= ref_exp; + acc_obj += (cur_obj - log_ref_z); + trg_words += smeta.GetReference().size(); + } + + virtual void NotifyDecodingComplete(const SentenceMetadata& smeta) { + if (state == 3) { + ++total_complete; + } else { + } + } + + int total_complete; + SparseVector<prob_t> cur_model_exp; + SparseVector<prob_t> acc_grad; + double acc_obj; + double cur_obj; + unsigned trg_words; + int state; +}; + +void ReadConfig(const string& ini, vector<string>* out) { + ReadFile rf(ini); + istream& in = *rf.stream(); + while(in) { + string line; + getline(in, line); + if (!in) continue; + out->push_back(line); + } +} + +void StoreConfig(const vector<string>& cfg, istringstream* o) { + ostringstream os; + for (int i = 0; i < cfg.size(); ++i) { os << cfg[i] << endl; } + o->str(os.str()); +} + +template <typename T> +struct VectorPlus : public binary_function<vector<T>, vector<T>, vector<T> > { + vector<T> operator()(const vector<int>& a, const vector<int>& b) const { + assert(a.size() == b.size()); + vector<T> v(a.size()); + transform(a.begin(), a.end(), b.begin(), v.begin(), plus<T>()); + return v; + } +}; + +int main(int argc, char** argv) { +#ifdef HAVE_MPI + mpi::environment env(argc, argv); + mpi::communicator world; + const int size = world.size(); + const int rank = world.rank(); +#else + const int size = 1; + const int rank = 0; +#endif + SetSilent(true); // turn off verbose decoder output + register_feature_functions(); + + po::variables_map conf; + if (!InitCommandLine(argc, argv, &conf)) return 1; + + // load cdec.ini and set up decoder + vector<string> cdec_ini; + ReadConfig(conf["decoder_config"].as<string>(), &cdec_ini); + istringstream ini; + StoreConfig(cdec_ini, &ini); + if (rank == 0) cerr << "Loading grammar...\n"; + Decoder* decoder = new Decoder(&ini); + if (decoder->GetConf()["input"].as<string>() != "-") { + cerr << "cdec.ini must not set an input file\n"; + return 1; + } + if (rank == 0) cerr << "Done loading grammar!\n"; + + // load initial weights + if (rank == 0) { cerr << "Loading weights...\n"; } + vector<weight_t>& lambdas = decoder->CurrentWeightVector(); + Weights::InitFromFile(conf["input_weights"].as<string>(), &lambdas); + if (rank == 0) { cerr << "Done loading weights.\n"; } + + // freeze feature set (should be optional?) + const bool freeze_feature_set = true; + if (freeze_feature_set) FD::Freeze(); + + const int num_feats = FD::NumFeats(); + if (rank == 0) cerr << "Number of features: " << num_feats << endl; + lambdas.resize(num_feats); + + const bool gaussian_prior = conf.count("gaussian_prior"); + vector<weight_t> means(num_feats, 0); + if (conf.count("means")) { + if (!gaussian_prior) { + cerr << "Don't use --means without --gaussian_prior!\n"; + exit(1); + } + Weights::InitFromFile(conf["means"].as<string>(), &means); + } + boost::shared_ptr<BatchOptimizer> o; + if (rank == 0) { + const string omethod = conf["optimization_method"].as<string>(); + if (omethod == "rprop") + o.reset(new RPropOptimizer(num_feats)); // TODO add configuration + else + o.reset(new LBFGSOptimizer(num_feats, conf["correction_buffers"].as<int>())); + cerr << "Optimizer: " << o->Name() << endl; + } + double objective = 0; + vector<double> gradient(num_feats, 0.0); + vector<double> rcv_grad; + rcv_grad.clear(); + bool converged = false; + + vector<string> corpus, test_corpus; + ReadTrainingCorpus(conf["training_data"].as<string>(), rank, size, &corpus); + assert(corpus.size() > 0); + if (conf.count("test_data")) + ReadTrainingCorpus(conf["test_data"].as<string>(), rank, size, &test_corpus); + + TrainingObserver observer; + ConditionalLikelihoodObserver cllh_observer; + while (!converged) { + observer.Reset(); + cllh_observer.Reset(); +#ifdef HAVE_MPI + mpi::timer timer; + world.barrier(); +#endif + if (rank == 0) { + cerr << "Starting decoding... (~" << corpus.size() << " sentences / proc)\n"; + cerr << " Testset size: " << test_corpus.size() << " sentences / proc)\n"; + } + for (int i = 0; i < corpus.size(); ++i) + decoder->Decode(corpus[i], &observer); + cerr << " process " << rank << '/' << size << " done\n"; + fill(gradient.begin(), gradient.end(), 0); + observer.SetLocalGradientAndObjective(&gradient, &objective); + + unsigned total_words = 0; +#ifdef HAVE_MPI + double to = 0; + rcv_grad.resize(num_feats, 0.0); + mpi::reduce(world, &gradient[0], gradient.size(), &rcv_grad[0], plus<double>(), 0); + swap(gradient, rcv_grad); + rcv_grad.clear(); + + reduce(world, observer.trg_words, total_words, std::plus<unsigned>(), 0); + mpi::reduce(world, objective, to, plus<double>(), 0); + objective = to; +#else + total_words = observer.trg_words; +#endif + if (rank == 0) + cerr << "TRAINING CORPUS: ln p(f|e)=" << objective << "\t log_2 p(f|e) = " << (objective/log(2)) << "\t cond. entropy = " << (objective/log(2) / total_words) << "\t ppl = " << pow(2, (objective/log(2) / total_words)) << endl; + + for (int i = 0; i < test_corpus.size(); ++i) + decoder->Decode(test_corpus[i], &cllh_observer); + + double test_objective = 0; + unsigned test_total_words = 0; +#ifdef HAVE_MPI + reduce(world, cllh_observer.acc_obj, test_objective, std::plus<double>(), 0); + reduce(world, cllh_observer.trg_words, test_total_words, std::plus<unsigned>(), 0); +#else + test_objective = cllh_observer.acc_obj; + test_total_words = cllh_observer.trg_words; +#endif + + if (rank == 0) { // run optimizer only on rank=0 node + if (test_corpus.size()) + cerr << " TEST CORPUS: ln p(f|e)=" << test_objective << "\t log_2 p(f|e) = " << (test_objective/log(2)) << "\t cond. entropy = " << (test_objective/log(2) / test_total_words) << "\t ppl = " << pow(2, (test_objective/log(2) / test_total_words)) << endl; + if (gaussian_prior) { + const double sigsq = conf["sigma_squared"].as<double>(); + double norm = 0; + for (int k = 1; k < lambdas.size(); ++k) { + const double& lambda_k = lambdas[k]; + if (lambda_k) { + const double param = (lambda_k - means[k]); + norm += param * param; + gradient[k] += param / sigsq; + } + } + const double reg = norm / (2.0 * sigsq); + cerr << "REGULARIZATION TERM: " << reg << endl; + objective += reg; + } + cerr << "EVALUATION #" << o->EvaluationCount() << " OBJECTIVE: " << objective << endl; + double gnorm = 0; + for (int i = 0; i < gradient.size(); ++i) + gnorm += gradient[i] * gradient[i]; + cerr << " GNORM=" << sqrt(gnorm) << endl; + vector<weight_t> old = lambdas; + int c = 0; + while (old == lambdas) { + ++c; + if (c > 1) { cerr << "Same lambdas, repeating optimization\n"; } + o->Optimize(objective, gradient, &lambdas); + assert(c < 5); + } + old.clear(); + Weights::SanityCheck(lambdas); + Weights::ShowLargestFeatures(lambdas); + + converged = o->HasConverged(); + if (converged) { cerr << "OPTIMIZER REPORTS CONVERGENCE!\n"; } + + string fname = "weights.cur.gz"; + if (converged) { fname = "weights.final.gz"; } + ostringstream vv; + vv << "Objective = " << objective << " (eval count=" << o->EvaluationCount() << ")"; + const string svv = vv.str(); + Weights::WriteToFile(fname, lambdas, true, &svv); + } // rank == 0 + int cint = converged; +#ifdef HAVE_MPI + mpi::broadcast(world, &lambdas[0], lambdas.size(), 0); + mpi::broadcast(world, cint, 0); + if (rank == 0) { cerr << " ELAPSED TIME THIS ITERATION=" << timer.elapsed() << endl; } +#endif + converged = cint; + } + return 0; +} + diff --git a/training/crf/mpi_baum_welch.cc b/training/crf/mpi_baum_welch.cc new file mode 100644 index 00000000..d69b1769 --- /dev/null +++ b/training/crf/mpi_baum_welch.cc @@ -0,0 +1,316 @@ +#include <sstream> +#include <iostream> +#include <vector> +#include <cassert> +#include <cmath> + +#include "config.h" +#ifdef HAVE_MPI +#include <boost/mpi/timer.hpp> +#include <boost/mpi.hpp> +namespace mpi = boost::mpi; +#endif + +#include <boost/unordered_map.hpp> +#include <boost/functional/hash.hpp> +#include <boost/shared_ptr.hpp> +#include <boost/program_options.hpp> +#include <boost/program_options/variables_map.hpp> + +#include "sentence_metadata.h" +#include "verbose.h" +#include "hg.h" +#include "prob.h" +#include "inside_outside.h" +#include "ff_register.h" +#include "decoder.h" +#include "filelib.h" +#include "stringlib.h" +#include "fdict.h" +#include "weights.h" +#include "sparse_vector.h" + +using namespace std; +namespace po = boost::program_options; + +bool InitCommandLine(int argc, char** argv, po::variables_map* conf) { + po::options_description opts("Configuration options"); + opts.add_options() + ("input_weights,w",po::value<string>(),"Input feature weights file") + ("iterations,n",po::value<unsigned>()->default_value(50), "Number of training iterations") + ("training_data,t",po::value<string>(),"Training data") + ("decoder_config,c",po::value<string>(),"Decoder configuration file"); + po::options_description clo("Command line options"); + clo.add_options() + ("config", po::value<string>(), "Configuration file") + ("help,h", "Print this help message and exit"); + po::options_description dconfig_options, dcmdline_options; + dconfig_options.add(opts); + dcmdline_options.add(opts).add(clo); + + po::store(parse_command_line(argc, argv, dcmdline_options), *conf); + if (conf->count("config")) { + ifstream config((*conf)["config"].as<string>().c_str()); + po::store(po::parse_config_file(config, dconfig_options), *conf); + } + po::notify(*conf); + + if (conf->count("help") || !conf->count("input_weights") || !(conf->count("training_data")) || !conf->count("decoder_config")) { + cerr << dcmdline_options << endl; + return false; + } + return true; +} + +void ReadTrainingCorpus(const string& fname, int rank, int size, vector<string>* c) { + ReadFile rf(fname); + istream& in = *rf.stream(); + string line; + int lc = 0; + while(in) { + getline(in, line); + if (!in) break; + if (lc % size == rank) c->push_back(line); + ++lc; + } +} + +static const double kMINUS_EPSILON = -1e-6; + +struct TrainingObserver : public DecoderObserver { + void Reset() { + acc_grad.clear(); + acc_obj = 0; + total_complete = 0; + trg_words = 0; + } + + void SetLocalGradientAndObjective(vector<double>* g, double* o) const { + *o = acc_obj; + for (SparseVector<double>::const_iterator it = acc_grad.begin(); it != acc_grad.end(); ++it) + (*g)[it->first] = it->second; + } + + virtual void NotifyDecodingStart(const SentenceMetadata& smeta) { + state = 1; + } + + // compute model expectations, denominator of objective + virtual void NotifyTranslationForest(const SentenceMetadata& smeta, Hypergraph* hg) { + assert(state == 1); + trg_words += smeta.GetSourceLength(); + state = 2; + SparseVector<prob_t> exps; + const prob_t z = InsideOutside<prob_t, + EdgeProb, + SparseVector<prob_t>, + EdgeFeaturesAndProbWeightFunction>(*hg, &exps); + exps /= z; + for (SparseVector<prob_t>::iterator it = exps.begin(); it != exps.end(); ++it) + acc_grad.add_value(it->first, it->second.as_float()); + + acc_obj += log(z); + } + + // compute "empirical" expectations, numerator of objective + virtual void NotifyAlignmentForest(const SentenceMetadata& smeta, Hypergraph* hg) { + cerr << "Shouldn't get an alignment forest!\n"; + abort(); + } + + virtual void NotifyDecodingComplete(const SentenceMetadata& smeta) { + ++total_complete; + } + + int total_complete; + SparseVector<double> acc_grad; + double acc_obj; + unsigned trg_words; + int state; +}; + +void ReadConfig(const string& ini, vector<string>* out) { + ReadFile rf(ini); + istream& in = *rf.stream(); + while(in) { + string line; + getline(in, line); + if (!in) continue; + out->push_back(line); + } +} + +void StoreConfig(const vector<string>& cfg, istringstream* o) { + ostringstream os; + for (int i = 0; i < cfg.size(); ++i) { os << cfg[i] << endl; } + o->str(os.str()); +} + +#if 0 +template <typename T> +struct VectorPlus : public binary_function<vector<T>, vector<T>, vector<T> > { + vector<T> operator()(const vector<int>& a, const vector<int>& b) const { + assert(a.size() == b.size()); + vector<T> v(a.size()); + transform(a.begin(), a.end(), b.begin(), v.begin(), plus<T>()); + return v; + } +}; +#endif + +int main(int argc, char** argv) { +#ifdef HAVE_MPI + mpi::environment env(argc, argv); + mpi::communicator world; + const int size = world.size(); + const int rank = world.rank(); +#else + const int size = 1; + const int rank = 0; +#endif + SetSilent(true); // turn off verbose decoder output + register_feature_functions(); + + po::variables_map conf; + if (!InitCommandLine(argc, argv, &conf)) return 1; + const unsigned iterations = conf["iterations"].as<unsigned>(); + + // load cdec.ini and set up decoder + vector<string> cdec_ini; + ReadConfig(conf["decoder_config"].as<string>(), &cdec_ini); + istringstream ini; + StoreConfig(cdec_ini, &ini); + Decoder* decoder = new Decoder(&ini); + if (decoder->GetConf()["input"].as<string>() != "-") { + cerr << "cdec.ini must not set an input file\n"; + return 1; + } + + // load initial weights + if (rank == 0) { cerr << "Loading weights...\n"; } + vector<weight_t>& lambdas = decoder->CurrentWeightVector(); + Weights::InitFromFile(conf["input_weights"].as<string>(), &lambdas); + if (rank == 0) { cerr << "Done loading weights.\n"; } + + // freeze feature set (should be optional?) + const bool freeze_feature_set = true; + if (freeze_feature_set) FD::Freeze(); + + const int num_feats = FD::NumFeats(); + if (rank == 0) cerr << "Number of features: " << num_feats << endl; + lambdas.resize(num_feats); + + vector<double> gradient(num_feats, 0.0); + vector<double> rcv_grad; + rcv_grad.clear(); + bool converged = false; + + vector<string> corpus, test_corpus; + ReadTrainingCorpus(conf["training_data"].as<string>(), rank, size, &corpus); + assert(corpus.size() > 0); + if (conf.count("test_data")) + ReadTrainingCorpus(conf["test_data"].as<string>(), rank, size, &test_corpus); + + // build map from feature id to the accumulator that should normalize + boost::unordered_map<std::string, boost::unordered_map<int, double>, boost::hash<std::string> > ccs; + vector<boost::unordered_map<int, double>* > cpd_to_acc; + if (rank == 0) { + cpd_to_acc.resize(num_feats); + for (unsigned f = 1; f < num_feats; ++f) { + string normalizer; + //0 ||| 7 9 ||| Bi:BOS_7=1 Bi:7_9=1 Bi:9_EOS=1 Id:a:7=1 Uni:7=1 Id:b:9=1 Uni:9=1 ||| 0 + const string& fstr = FD::Convert(f); + if (fstr.find("Bi:") == 0) { + size_t pos = fstr.rfind('_'); + if (pos < fstr.size()) + normalizer = fstr.substr(0, pos); + } else if (fstr.find("Id:") == 0) { + size_t pos = fstr.rfind(':'); + if (pos < fstr.size()) { + normalizer = "Emit:"; + normalizer += fstr.substr(pos); + } + } + if (normalizer.size() > 0) { + boost::unordered_map<int, double>& acc = ccs[normalizer]; + cpd_to_acc[f] = &acc; + } + } + } + + TrainingObserver observer; + int iteration = 0; + while (!converged) { + ++iteration; + observer.Reset(); +#ifdef HAVE_MPI + mpi::timer timer; + world.barrier(); +#endif + if (rank == 0) { + cerr << "Starting decoding... (~" << corpus.size() << " sentences / proc)\n"; + cerr << " Testset size: " << test_corpus.size() << " sentences / proc)\n"; + for(boost::unordered_map<string, boost::unordered_map<int,double>, boost::hash<string> >::iterator it = ccs.begin(); it != ccs.end(); ++it) + it->second.clear(); + } + for (int i = 0; i < corpus.size(); ++i) + decoder->Decode(corpus[i], &observer); + cerr << " process " << rank << '/' << size << " done\n"; + fill(gradient.begin(), gradient.end(), 0); + double objective = 0; + observer.SetLocalGradientAndObjective(&gradient, &objective); + + unsigned total_words = 0; +#ifdef HAVE_MPI + double to = 0; + rcv_grad.resize(num_feats, 0.0); + mpi::reduce(world, &gradient[0], gradient.size(), &rcv_grad[0], plus<double>(), 0); + swap(gradient, rcv_grad); + rcv_grad.clear(); + + reduce(world, observer.trg_words, total_words, std::plus<unsigned>(), 0); + mpi::reduce(world, objective, to, plus<double>(), 0); + objective = to; +#else + total_words = observer.trg_words; +#endif + if (rank == 0) { // run optimizer only on rank=0 node + cerr << "TRAINING CORPUS: ln p(x)=" << objective << "\t log_2 p(x) = " << (objective/log(2)) << "\t cross entropy = " << (objective/log(2) / total_words) << "\t ppl = " << pow(2, (-objective/log(2) / total_words)) << endl; + for (unsigned f = 1; f < num_feats; ++f) { + boost::unordered_map<int, double>* m = cpd_to_acc[f]; + if (m && gradient[f]) { + (*m)[f] += gradient[f]; + } + for(boost::unordered_map<string, boost::unordered_map<int,double>, boost::hash<string> >::iterator it = ccs.begin(); it != ccs.end(); ++it) { + const boost::unordered_map<int,double>& ccs = it->second; + double z = 0; + for (boost::unordered_map<int,double>::const_iterator ci = ccs.begin(); ci != ccs.end(); ++ci) + z += ci->second + 1e-09; + double lz = log(z); + for (boost::unordered_map<int,double>::const_iterator ci = ccs.begin(); ci != ccs.end(); ++ci) + lambdas[ci->first] = log(ci->second + 1e-09) - lz; + } + } + Weights::SanityCheck(lambdas); + Weights::ShowLargestFeatures(lambdas); + + converged = (iteration == iterations); + + string fname = "weights.cur.gz"; + if (converged) { fname = "weights.final.gz"; } + ostringstream vv; + vv << "Objective = " << objective << " (eval count=" << iteration << ")"; + const string svv = vv.str(); + Weights::WriteToFile(fname, lambdas, true, &svv); + } // rank == 0 + int cint = converged; +#ifdef HAVE_MPI + mpi::broadcast(world, &lambdas[0], lambdas.size(), 0); + mpi::broadcast(world, cint, 0); + if (rank == 0) { cerr << " ELAPSED TIME THIS ITERATION=" << timer.elapsed() << endl; } +#endif + converged = cint; + } + return 0; +} + diff --git a/training/crf/mpi_compute_cllh.cc b/training/crf/mpi_compute_cllh.cc new file mode 100644 index 00000000..066389d0 --- /dev/null +++ b/training/crf/mpi_compute_cllh.cc @@ -0,0 +1,134 @@ +#include <iostream> +#include <vector> +#include <cassert> +#include <cmath> + +#include "config.h" +#ifdef HAVE_MPI +#include <boost/mpi.hpp> +#endif +#include <boost/program_options.hpp> +#include <boost/program_options/variables_map.hpp> + +#include "cllh_observer.h" +#include "sentence_metadata.h" +#include "verbose.h" +#include "hg.h" +#include "prob.h" +#include "inside_outside.h" +#include "ff_register.h" +#include "decoder.h" +#include "filelib.h" +#include "weights.h" + +using namespace std; +namespace po = boost::program_options; + +bool InitCommandLine(int argc, char** argv, po::variables_map* conf) { + po::options_description opts("Configuration options"); + opts.add_options() + ("weights,w",po::value<string>(),"Input feature weights file") + ("training_data,t",po::value<string>(),"Training data corpus") + ("decoder_config,c",po::value<string>(),"Decoder configuration file"); + po::options_description clo("Command line options"); + clo.add_options() + ("config", po::value<string>(), "Configuration file") + ("help,h", "Print this help message and exit"); + po::options_description dconfig_options, dcmdline_options; + dconfig_options.add(opts); + dcmdline_options.add(opts).add(clo); + + po::store(parse_command_line(argc, argv, dcmdline_options), *conf); + if (conf->count("config")) { + ifstream config((*conf)["config"].as<string>().c_str()); + po::store(po::parse_config_file(config, dconfig_options), *conf); + } + po::notify(*conf); + + if (conf->count("help") || !conf->count("training_data") || !conf->count("decoder_config")) { + cerr << dcmdline_options << endl; + return false; + } + return true; +} + +void ReadInstances(const string& fname, int rank, int size, vector<string>* c) { + assert(fname != "-"); + ReadFile rf(fname); + istream& in = *rf.stream(); + string line; + int lc = 0; + while(in) { + getline(in, line); + if (!in) break; + if (lc % size == rank) c->push_back(line); + ++lc; + } +} + +static const double kMINUS_EPSILON = -1e-6; + +#ifdef HAVE_MPI +namespace mpi = boost::mpi; +#endif + +int main(int argc, char** argv) { +#ifdef HAVE_MPI + mpi::environment env(argc, argv); + mpi::communicator world; + const int size = world.size(); + const int rank = world.rank(); +#else + const int size = 1; + const int rank = 0; +#endif + if (size > 1) SetSilent(true); // turn off verbose decoder output + register_feature_functions(); + + po::variables_map conf; + if (!InitCommandLine(argc, argv, &conf)) + return false; + + // load cdec.ini and set up decoder + ReadFile ini_rf(conf["decoder_config"].as<string>()); + Decoder decoder(ini_rf.stream()); + if (decoder.GetConf()["input"].as<string>() != "-") { + cerr << "cdec.ini must not set an input file\n"; + abort(); + } + + // load weights + vector<weight_t>& weights = decoder.CurrentWeightVector(); + if (conf.count("weights")) + Weights::InitFromFile(conf["weights"].as<string>(), &weights); + + vector<string> corpus; + ReadInstances(conf["training_data"].as<string>(), rank, size, &corpus); + assert(corpus.size() > 0); + + if (rank == 0) + cerr << "Each processor is decoding ~" << corpus.size() << " training examples...\n"; + + ConditionalLikelihoodObserver observer; + for (int i = 0; i < corpus.size(); ++i) + decoder.Decode(corpus[i], &observer); + + double objective = 0; + unsigned total_words = 0; +#ifdef HAVE_MPI + reduce(world, observer.acc_obj, objective, std::plus<double>(), 0); + reduce(world, observer.trg_words, total_words, std::plus<unsigned>(), 0); +#else + objective = observer.acc_obj; +#endif + + if (rank == 0) { + cout << "CONDITIONAL LOG_e LIKELIHOOD: " << objective << endl; + cout << "CONDITIONAL LOG_2 LIKELIHOOD: " << (objective/log(2)) << endl; + cout << " CONDITIONAL ENTROPY: " << (objective/log(2) / total_words) << endl; + cout << " PERPLEXITY: " << pow(2, (objective/log(2) / total_words)) << endl; + } + + return 0; +} + diff --git a/training/crf/mpi_extract_features.cc b/training/crf/mpi_extract_features.cc new file mode 100644 index 00000000..6750aa15 --- /dev/null +++ b/training/crf/mpi_extract_features.cc @@ -0,0 +1,151 @@ +#include <iostream> +#include <sstream> +#include <vector> +#include <cassert> + +#include "config.h" +#ifdef HAVE_MPI +#include <boost/mpi.hpp> +#endif +#include <boost/program_options.hpp> +#include <boost/program_options/variables_map.hpp> + +#include "ff_register.h" +#include "verbose.h" +#include "filelib.h" +#include "fdict.h" +#include "decoder.h" +#include "weights.h" + +using namespace std; +namespace po = boost::program_options; + +bool InitCommandLine(int argc, char** argv, po::variables_map* conf) { + po::options_description opts("Configuration options"); + opts.add_options() + ("training_data,t",po::value<string>(),"Training data corpus") + ("decoder_config,c",po::value<string>(),"Decoder configuration file") + ("weights,w", po::value<string>(), "(Optional) weights file; weights may affect what features are encountered in pruning configurations") + ("output_prefix,o",po::value<string>()->default_value("features"),"Output path prefix"); + po::options_description clo("Command line options"); + clo.add_options() + ("config", po::value<string>(), "Configuration file") + ("help,h", "Print this help message and exit"); + po::options_description dconfig_options, dcmdline_options; + dconfig_options.add(opts); + dcmdline_options.add(opts).add(clo); + + po::store(parse_command_line(argc, argv, dcmdline_options), *conf); + if (conf->count("config")) { + ifstream config((*conf)["config"].as<string>().c_str()); + po::store(po::parse_config_file(config, dconfig_options), *conf); + } + po::notify(*conf); + + if (conf->count("help") || !conf->count("training_data") || !conf->count("decoder_config")) { + cerr << "Decode an input set (optionally in parallel using MPI) and write\nout the feature strings encountered.\n"; + cerr << dcmdline_options << endl; + return false; + } + return true; +} + +void ReadTrainingCorpus(const string& fname, int rank, int size, vector<string>* c) { + ReadFile rf(fname); + istream& in = *rf.stream(); + string line; + int lc = 0; + while(in) { + getline(in, line); + if (!in) break; + if (lc % size == rank) c->push_back(line); + ++lc; + } +} + +static const double kMINUS_EPSILON = -1e-6; + +struct TrainingObserver : public DecoderObserver { + + virtual void NotifyDecodingStart(const SentenceMetadata&) { + } + + // compute model expectations, denominator of objective + virtual void NotifyTranslationForest(const SentenceMetadata&, Hypergraph* hg) { + } + + // compute "empirical" expectations, numerator of objective + virtual void NotifyAlignmentForest(const SentenceMetadata& smeta, Hypergraph* hg) { + } +}; + +#ifdef HAVE_MPI +namespace mpi = boost::mpi; +#endif + +int main(int argc, char** argv) { +#ifdef HAVE_MPI + mpi::environment env(argc, argv); + mpi::communicator world; + const int size = world.size(); + const int rank = world.rank(); +#else + const int size = 1; + const int rank = 0; +#endif + if (size > 1) SetSilent(true); // turn off verbose decoder output + register_feature_functions(); + + po::variables_map conf; + if (!InitCommandLine(argc, argv, &conf)) + return false; + + // load cdec.ini and set up decoder + ReadFile ini_rf(conf["decoder_config"].as<string>()); + Decoder decoder(ini_rf.stream()); + if (decoder.GetConf()["input"].as<string>() != "-") { + cerr << "cdec.ini must not set an input file\n"; + abort(); + } + + if (FD::UsingPerfectHashFunction()) { + cerr << "Your configuration file has enabled a cmph hash function. Please disable.\n"; + return 1; + } + + // load optional weights + if (conf.count("weights")) + Weights::InitFromFile(conf["weights"].as<string>(), &decoder.CurrentWeightVector()); + + vector<string> corpus; + ReadTrainingCorpus(conf["training_data"].as<string>(), rank, size, &corpus); + assert(corpus.size() > 0); + + TrainingObserver observer; + + if (rank == 0) + cerr << "Each processor is decoding ~" << corpus.size() << " training examples...\n"; + + for (int i = 0; i < corpus.size(); ++i) + decoder.Decode(corpus[i], &observer); + + { + ostringstream os; + os << conf["output_prefix"].as<string>() << '.' << rank << "_of_" << size; + WriteFile wf(os.str()); + ostream& out = *wf.stream(); + const unsigned num_feats = FD::NumFeats(); + for (unsigned i = 1; i < num_feats; ++i) { + out << FD::Convert(i) << endl; + } + cerr << "Wrote " << os.str() << endl; + } + +#ifdef HAVE_MPI + world.barrier(); +#else +#endif + + return 0; +} + diff --git a/training/crf/mpi_extract_reachable.cc b/training/crf/mpi_extract_reachable.cc new file mode 100644 index 00000000..2a7c2b9d --- /dev/null +++ b/training/crf/mpi_extract_reachable.cc @@ -0,0 +1,163 @@ +#include <iostream> +#include <sstream> +#include <vector> +#include <cassert> + +#include "config.h" +#ifdef HAVE_MPI +#include <boost/mpi.hpp> +#endif +#include <boost/program_options.hpp> +#include <boost/program_options/variables_map.hpp> + +#include "ff_register.h" +#include "verbose.h" +#include "filelib.h" +#include "fdict.h" +#include "decoder.h" +#include "weights.h" + +using namespace std; +namespace po = boost::program_options; + +bool InitCommandLine(int argc, char** argv, po::variables_map* conf) { + po::options_description opts("Configuration options"); + opts.add_options() + ("training_data,t",po::value<string>(),"Training data corpus") + ("decoder_config,c",po::value<string>(),"Decoder configuration file") + ("weights,w", po::value<string>(), "(Optional) weights file; weights may affect what features are encountered in pruning configurations") + ("output_prefix,o",po::value<string>()->default_value("reachable"),"Output path prefix"); + po::options_description clo("Command line options"); + clo.add_options() + ("config", po::value<string>(), "Configuration file") + ("help,h", "Print this help message and exit"); + po::options_description dconfig_options, dcmdline_options; + dconfig_options.add(opts); + dcmdline_options.add(opts).add(clo); + + po::store(parse_command_line(argc, argv, dcmdline_options), *conf); + if (conf->count("config")) { + ifstream config((*conf)["config"].as<string>().c_str()); + po::store(po::parse_config_file(config, dconfig_options), *conf); + } + po::notify(*conf); + + if (conf->count("help") || !conf->count("training_data") || !conf->count("decoder_config")) { + cerr << "Decode an input set (optionally in parallel using MPI) and write\nout the inputs that produce reachable parallel parses.\n"; + cerr << dcmdline_options << endl; + return false; + } + return true; +} + +void ReadTrainingCorpus(const string& fname, int rank, int size, vector<string>* c) { + ReadFile rf(fname); + istream& in = *rf.stream(); + string line; + int lc = 0; + while(in) { + getline(in, line); + if (!in) break; + if (lc % size == rank) c->push_back(line); + ++lc; + } +} + +static const double kMINUS_EPSILON = -1e-6; + +struct ReachabilityObserver : public DecoderObserver { + + virtual void NotifyDecodingStart(const SentenceMetadata&) { + reachable = false; + } + + // compute model expectations, denominator of objective + virtual void NotifyTranslationForest(const SentenceMetadata&, Hypergraph* hg) { + } + + // compute "empirical" expectations, numerator of objective + virtual void NotifyAlignmentForest(const SentenceMetadata& smeta, Hypergraph* hg) { + reachable = true; + } + + bool reachable; +}; + +#ifdef HAVE_MPI +namespace mpi = boost::mpi; +#endif + +int main(int argc, char** argv) { +#ifdef HAVE_MPI + mpi::environment env(argc, argv); + mpi::communicator world; + const int size = world.size(); + const int rank = world.rank(); +#else + const int size = 1; + const int rank = 0; +#endif + if (size > 1) SetSilent(true); // turn off verbose decoder output + register_feature_functions(); + + po::variables_map conf; + if (!InitCommandLine(argc, argv, &conf)) + return false; + + // load cdec.ini and set up decoder + ReadFile ini_rf(conf["decoder_config"].as<string>()); + Decoder decoder(ini_rf.stream()); + if (decoder.GetConf()["input"].as<string>() != "-") { + cerr << "cdec.ini must not set an input file\n"; + abort(); + } + + if (FD::UsingPerfectHashFunction()) { + cerr << "Your configuration file has enabled a cmph hash function. Please disable.\n"; + return 1; + } + + // load optional weights + if (conf.count("weights")) + Weights::InitFromFile(conf["weights"].as<string>(), &decoder.CurrentWeightVector()); + + vector<string> corpus; + ReadTrainingCorpus(conf["training_data"].as<string>(), rank, size, &corpus); + assert(corpus.size() > 0); + + + if (rank == 0) + cerr << "Each processor is decoding ~" << corpus.size() << " training examples...\n"; + + size_t num_reached = 0; + { + ostringstream os; + os << conf["output_prefix"].as<string>() << '.' << rank << "_of_" << size; + WriteFile wf(os.str()); + ostream& out = *wf.stream(); + ReachabilityObserver observer; + for (int i = 0; i < corpus.size(); ++i) { + decoder.Decode(corpus[i], &observer); + if (observer.reachable) { + out << corpus[i] << endl; + ++num_reached; + } + corpus[i].clear(); + } + cerr << "Shard " << rank << '/' << size << " finished, wrote " + << num_reached << " instances to " << os.str() << endl; + } + + size_t total = 0; +#ifdef HAVE_MPI + reduce(world, num_reached, total, std::plus<double>(), 0); +#else + total = num_reached; +#endif + if (rank == 0) { + cerr << "-----------------------------------------\n"; + cerr << "TOTAL = " << total << " instances\n"; + } + return 0; +} + diff --git a/training/crf/mpi_flex_optimize.cc b/training/crf/mpi_flex_optimize.cc new file mode 100644 index 00000000..b52decdc --- /dev/null +++ b/training/crf/mpi_flex_optimize.cc @@ -0,0 +1,386 @@ +#include <sstream> +#include <iostream> +#include <fstream> +#include <vector> +#include <cassert> +#include <cmath> + +#include <boost/shared_ptr.hpp> +#include <boost/program_options.hpp> +#include <boost/program_options/variables_map.hpp> + +#include "stringlib.h" +#include "verbose.h" +#include "hg.h" +#include "prob.h" +#include "inside_outside.h" +#include "ff_register.h" +#include "decoder.h" +#include "filelib.h" +#include "optimize.h" +#include "fdict.h" +#include "weights.h" +#include "sparse_vector.h" +#include "sampler.h" + +#ifdef HAVE_MPI +#include <boost/mpi/timer.hpp> +#include <boost/mpi.hpp> +namespace mpi = boost::mpi; +#endif + +using namespace std; +namespace po = boost::program_options; + +bool InitCommandLine(int argc, char** argv, po::variables_map* conf) { + po::options_description opts("Configuration options"); + opts.add_options() + ("cdec_config,c",po::value<string>(),"Decoder configuration file") + ("weights,w",po::value<string>(),"Initial feature weights") + ("training_data,d",po::value<string>(),"Training data") + ("minibatch_size_per_proc,s", po::value<unsigned>()->default_value(6), "Number of training instances evaluated per processor in each minibatch") + ("minibatch_iterations,i", po::value<unsigned>()->default_value(10), "Number of optimization iterations per minibatch") + ("iterations,I", po::value<unsigned>()->default_value(50), "Number of passes through the training data before termination") + ("regularization_strength,C", po::value<double>()->default_value(0.2), "Regularization strength") + ("time_series_strength,T", po::value<double>()->default_value(0.0), "Time series regularization strength") + ("random_seed,S", po::value<uint32_t>(), "Random seed (if not specified, /dev/random will be used)") + ("lbfgs_memory_buffers,M", po::value<unsigned>()->default_value(10), "Number of memory buffers for LBFGS history"); + po::options_description clo("Command line options"); + clo.add_options() + ("config", po::value<string>(), "Configuration file") + ("help,h", "Print this help message and exit"); + po::options_description dconfig_options, dcmdline_options; + dconfig_options.add(opts); + dcmdline_options.add(opts).add(clo); + + po::store(parse_command_line(argc, argv, dcmdline_options), *conf); + if (conf->count("config")) { + ifstream config((*conf)["config"].as<string>().c_str()); + po::store(po::parse_config_file(config, dconfig_options), *conf); + } + po::notify(*conf); + + if (conf->count("help") || !conf->count("training_data") || !conf->count("cdec_config")) { + cerr << "LBFGS minibatch online optimizer (MPI support " +#if HAVE_MPI + << "enabled" +#else + << "not enabled" +#endif + << ")\n" << dcmdline_options << endl; + return false; + } + return true; +} + +void ReadTrainingCorpus(const string& fname, int rank, int size, vector<string>* c, vector<int>* order) { + ReadFile rf(fname); + istream& in = *rf.stream(); + string line; + int id = 0; + while(in) { + getline(in, line); + if (!in) break; + if (id % size == rank) { + c->push_back(line); + order->push_back(id); + } + ++id; + } +} + +static const double kMINUS_EPSILON = -1e-6; + +struct CopyHGsObserver : public DecoderObserver { + Hypergraph* hg_; + Hypergraph* gold_hg_; + + // this can free up some memory + void RemoveRules(Hypergraph* h) { + for (unsigned i = 0; i < h->edges_.size(); ++i) + h->edges_[i].rule_.reset(); + } + + void SetCurrentHypergraphs(Hypergraph* h, Hypergraph* gold_h) { + hg_ = h; + gold_hg_ = gold_h; + } + + virtual void NotifyDecodingStart(const SentenceMetadata&) { + state = 1; + } + + // compute model expectations, denominator of objective + virtual void NotifyTranslationForest(const SentenceMetadata&, Hypergraph* hg) { + *hg_ = *hg; + RemoveRules(hg_); + assert(state == 1); + state = 2; + } + + // compute "empirical" expectations, numerator of objective + virtual void NotifyAlignmentForest(const SentenceMetadata&, Hypergraph* hg) { + assert(state == 2); + state = 3; + *gold_hg_ = *hg; + RemoveRules(gold_hg_); + } + + virtual void NotifyDecodingComplete(const SentenceMetadata&) { + if (state == 3) { + } else { + hg_->clear(); + gold_hg_->clear(); + } + } + + int state; +}; + +void ReadConfig(const string& ini, istringstream* out) { + ReadFile rf(ini); + istream& in = *rf.stream(); + ostringstream os; + while(in) { + string line; + getline(in, line); + if (!in) continue; + os << line << endl; + } + out->str(os.str()); +} + +#ifdef HAVE_MPI +namespace boost { namespace mpi { + template<> + struct is_commutative<std::plus<SparseVector<double> >, SparseVector<double> > + : mpl::true_ { }; +} } // end namespace boost::mpi +#endif + +void AddGrad(const SparseVector<prob_t> x, double s, SparseVector<double>* acc) { + for (SparseVector<prob_t>::const_iterator it = x.begin(); it != x.end(); ++it) + acc->add_value(it->first, it->second.as_float() * s); +} + +double PNorm(const vector<double>& v, const double p) { + double acc = 0; + for (int i = 0; i < v.size(); ++i) + acc += pow(v[i], p); + return pow(acc, 1.0 / p); +} + +void VV(ostream&os, const vector<double>& v) { + for (int i = 1; i < v.size(); ++i) + if (v[i]) os << FD::Convert(i) << "=" << v[i] << " "; +} + +double ApplyRegularizationTerms(const double C, + const double T, + const vector<double>& weights, + const vector<double>& prev_weights, + double* g) { + double reg = 0; + for (size_t i = 0; i < weights.size(); ++i) { + const double prev_w_i = (i < prev_weights.size() ? prev_weights[i] : 0.0); + const double& w_i = weights[i]; + reg += C * w_i * w_i; + g[i] += 2 * C * w_i; + + reg += T * (w_i - prev_w_i) * (w_i - prev_w_i); + g[i] += 2 * T * (w_i - prev_w_i); + } + return reg; +} + +int main(int argc, char** argv) { +#ifdef HAVE_MPI + mpi::environment env(argc, argv); + mpi::communicator world; + const int size = world.size(); + const int rank = world.rank(); +#else + const int size = 1; + const int rank = 0; +#endif + if (size > 1) SetSilent(true); // turn off verbose decoder output + register_feature_functions(); + MT19937* rng = NULL; + + po::variables_map conf; + if (!InitCommandLine(argc, argv, &conf)) + return 1; + + boost::shared_ptr<BatchOptimizer> o; + const unsigned lbfgs_memory_buffers = conf["lbfgs_memory_buffers"].as<unsigned>(); + const unsigned size_per_proc = conf["minibatch_size_per_proc"].as<unsigned>(); + const unsigned minibatch_iterations = conf["minibatch_iterations"].as<unsigned>(); + const double regularization_strength = conf["regularization_strength"].as<double>(); + const double time_series_strength = conf["time_series_strength"].as<double>(); + const bool use_time_series_reg = time_series_strength > 0.0; + const unsigned max_iteration = conf["iterations"].as<unsigned>(); + + vector<string> corpus; + vector<int> ids; + ReadTrainingCorpus(conf["training_data"].as<string>(), rank, size, &corpus, &ids); + assert(corpus.size() > 0); + + if (size_per_proc > corpus.size()) { + cerr << "Minibatch size (per processor) must be smaller or equal to the local corpus size!\n"; + return 1; + } + + // initialize decoder (loads hash functions if necessary) + istringstream ins; + ReadConfig(conf["cdec_config"].as<string>(), &ins); + Decoder decoder(&ins); + + // load initial weights + vector<weight_t> prev_weights; + if (conf.count("weights")) + Weights::InitFromFile(conf["weights"].as<string>(), &prev_weights); + + if (conf.count("random_seed")) + rng = new MT19937(conf["random_seed"].as<uint32_t>()); + else + rng = new MT19937; + + size_t total_corpus_size = 0; +#ifdef HAVE_MPI + reduce(world, corpus.size(), total_corpus_size, std::plus<size_t>(), 0); +#else + total_corpus_size = corpus.size(); +#endif + + if (rank == 0) + cerr << "Total corpus size: " << total_corpus_size << endl; + + CopyHGsObserver observer; + + int write_weights_every_ith = 100; // TODO configure + int titer = -1; + + vector<weight_t>& cur_weights = decoder.CurrentWeightVector(); + if (use_time_series_reg) { + cur_weights = prev_weights; + } else { + cur_weights.swap(prev_weights); + prev_weights.clear(); + } + + int iter = -1; + bool converged = false; + vector<double> gg; + while (!converged) { +#ifdef HAVE_MPI + mpi::timer timer; +#endif + ++iter; ++titer; + if (rank == 0) { + converged = (iter == max_iteration); + string fname = "weights.cur.gz"; + if (iter % write_weights_every_ith == 0) { + ostringstream o; o << "weights.epoch_" << iter << ".gz"; + fname = o.str(); + } + if (converged) { fname = "weights.final.gz"; } + ostringstream vv; + vv << "total iter=" << titer << " (of current config iter=" << iter << ") minibatch=" << size_per_proc << " sentences/proc x " << size << " procs. num_feats=" << FD::NumFeats() << " passes_thru_data=" << (titer * size_per_proc / static_cast<double>(corpus.size())); + const string svv = vv.str(); + Weights::WriteToFile(fname, cur_weights, true, &svv); + } + + vector<Hypergraph> hgs(size_per_proc); + vector<Hypergraph> gold_hgs(size_per_proc); + for (int i = 0; i < size_per_proc; ++i) { + int ei = corpus.size() * rng->next(); + int id = ids[ei]; + observer.SetCurrentHypergraphs(&hgs[i], &gold_hgs[i]); + decoder.SetId(id); + decoder.Decode(corpus[ei], &observer); + } + + SparseVector<double> local_grad, g; + double local_obj = 0; + o.reset(); + for (unsigned mi = 0; mi < minibatch_iterations; ++mi) { + local_grad.clear(); + g.clear(); + local_obj = 0; + + for (unsigned i = 0; i < size_per_proc; ++i) { + Hypergraph& hg = hgs[i]; + Hypergraph& hg_gold = gold_hgs[i]; + if (hg.edges_.size() < 2) continue; + + hg.Reweight(cur_weights); + hg_gold.Reweight(cur_weights); + SparseVector<prob_t> model_exp, gold_exp; + const prob_t z = InsideOutside<prob_t, + EdgeProb, + SparseVector<prob_t>, + EdgeFeaturesAndProbWeightFunction>(hg, &model_exp); + local_obj += log(z); + model_exp /= z; + AddGrad(model_exp, 1.0, &local_grad); + model_exp.clear(); + + const prob_t goldz = InsideOutside<prob_t, + EdgeProb, + SparseVector<prob_t>, + EdgeFeaturesAndProbWeightFunction>(hg_gold, &gold_exp); + local_obj -= log(goldz); + + if (log(z) - log(goldz) < kMINUS_EPSILON) { + cerr << "DIFF. ERR! log_model_z < log_gold_z: " << log(z) << " " << log(goldz) << endl; + return 1; + } + + gold_exp /= goldz; + AddGrad(gold_exp, -1.0, &local_grad); + } + + double obj = 0; +#ifdef HAVE_MPI + reduce(world, local_obj, obj, std::plus<double>(), 0); + reduce(world, local_grad, g, std::plus<SparseVector<double> >(), 0); +#else + obj = local_obj; + g.swap(local_grad); +#endif + local_grad.clear(); + if (rank == 0) { + // g /= (size_per_proc * size); + if (!o) + o.reset(new LBFGSOptimizer(FD::NumFeats(), lbfgs_memory_buffers)); + gg.clear(); + gg.resize(FD::NumFeats()); + if (gg.size() != cur_weights.size()) { cur_weights.resize(gg.size()); } + for (SparseVector<double>::iterator it = g.begin(); it != g.end(); ++it) + if (it->first) { gg[it->first] = it->second; } + g.clear(); + double r = ApplyRegularizationTerms(regularization_strength, + time_series_strength, // * (iter == 0 ? 0.0 : 1.0), + cur_weights, + prev_weights, + &gg[0]); + obj += r; + if (mi == 0 || mi == (minibatch_iterations - 1)) { + if (!mi) cerr << iter << ' '; else cerr << ' '; + cerr << "OBJ=" << obj << " (REG=" << r << ")" << " |g|=" << PNorm(gg, 2) << " |w|=" << PNorm(cur_weights, 2); + if (mi > 0) cerr << endl << flush; else cerr << ' '; + } else { cerr << '.' << flush; } + // cerr << "w = "; VV(cerr, cur_weights); cerr << endl; + // cerr << "g = "; VV(cerr, gg); cerr << endl; + o->Optimize(obj, gg, &cur_weights); + } +#ifdef HAVE_MPI + broadcast(world, cur_weights, 0); + broadcast(world, converged, 0); + world.barrier(); +#endif + } + prev_weights = cur_weights; + } + return 0; +} diff --git a/training/crf/mpi_online_optimize.cc b/training/crf/mpi_online_optimize.cc new file mode 100644 index 00000000..9e1ae34c --- /dev/null +++ b/training/crf/mpi_online_optimize.cc @@ -0,0 +1,384 @@ +#include <sstream> +#include <iostream> +#include <fstream> +#include <vector> +#include <cassert> +#include <cmath> +#include <tr1/memory> +#include <ctime> + +#include <boost/program_options.hpp> +#include <boost/program_options/variables_map.hpp> + +#include "stringlib.h" +#include "verbose.h" +#include "hg.h" +#include "prob.h" +#include "inside_outside.h" +#include "ff_register.h" +#include "decoder.h" +#include "filelib.h" +#include "online_optimizer.h" +#include "fdict.h" +#include "weights.h" +#include "sparse_vector.h" +#include "sampler.h" + +#ifdef HAVE_MPI +#include <boost/mpi/timer.hpp> +#include <boost/mpi.hpp> +namespace mpi = boost::mpi; +#endif + +using namespace std; +namespace po = boost::program_options; + +bool InitCommandLine(int argc, char** argv, po::variables_map* conf) { + po::options_description opts("Configuration options"); + opts.add_options() + ("input_weights,w",po::value<string>(),"Input feature weights file") + ("frozen_features,z",po::value<string>(), "List of features not to optimize") + ("training_data,t",po::value<string>(),"Training data corpus") + ("training_agenda,a",po::value<string>(), "Text file listing a series of configuration files and the number of iterations to train using each configuration successively") + ("minibatch_size_per_proc,s", po::value<unsigned>()->default_value(5), "Number of training instances evaluated per processor in each minibatch") + ("optimization_method,m", po::value<string>()->default_value("sgd"), "Optimization method (sgd)") + ("max_walltime", po::value<unsigned>(), "Maximum walltime to run (in minutes)") + ("random_seed,S", po::value<uint32_t>(), "Random seed (if not specified, /dev/random will be used)") + ("eta_0,e", po::value<double>()->default_value(0.2), "Initial learning rate for SGD (eta_0)") + ("L1,1","Use L1 regularization") + ("regularization_strength,C", po::value<double>()->default_value(1.0), "Regularization strength (C)"); + po::options_description clo("Command line options"); + clo.add_options() + ("config", po::value<string>(), "Configuration file") + ("help,h", "Print this help message and exit"); + po::options_description dconfig_options, dcmdline_options; + dconfig_options.add(opts); + dcmdline_options.add(opts).add(clo); + + po::store(parse_command_line(argc, argv, dcmdline_options), *conf); + if (conf->count("config")) { + ifstream config((*conf)["config"].as<string>().c_str()); + po::store(po::parse_config_file(config, dconfig_options), *conf); + } + po::notify(*conf); + + if (conf->count("help") || !conf->count("training_data") || !conf->count("training_agenda")) { + cerr << dcmdline_options << endl; + return false; + } + return true; +} + +void ReadTrainingCorpus(const string& fname, int rank, int size, vector<string>* c, vector<int>* order) { + ReadFile rf(fname); + istream& in = *rf.stream(); + string line; + int id = 0; + while(in) { + getline(in, line); + if (!in) break; + if (id % size == rank) { + c->push_back(line); + order->push_back(id); + } + ++id; + } +} + +static const double kMINUS_EPSILON = -1e-6; + +struct TrainingObserver : public DecoderObserver { + void Reset() { + acc_grad.clear(); + acc_obj = 0; + total_complete = 0; + } + + void SetLocalGradientAndObjective(vector<double>* g, double* o) const { + *o = acc_obj; + for (SparseVector<prob_t>::const_iterator it = acc_grad.begin(); it != acc_grad.end(); ++it) + (*g)[it->first] = it->second.as_float(); + } + + virtual void NotifyDecodingStart(const SentenceMetadata& smeta) { + cur_model_exp.clear(); + cur_obj = 0; + state = 1; + } + + // compute model expectations, denominator of objective + virtual void NotifyTranslationForest(const SentenceMetadata& smeta, Hypergraph* hg) { + assert(state == 1); + state = 2; + const prob_t z = InsideOutside<prob_t, + EdgeProb, + SparseVector<prob_t>, + EdgeFeaturesAndProbWeightFunction>(*hg, &cur_model_exp); + cur_obj = log(z); + cur_model_exp /= z; + } + + // compute "empirical" expectations, numerator of objective + virtual void NotifyAlignmentForest(const SentenceMetadata& smeta, Hypergraph* hg) { + assert(state == 2); + state = 3; + SparseVector<prob_t> ref_exp; + const prob_t ref_z = InsideOutside<prob_t, + EdgeProb, + SparseVector<prob_t>, + EdgeFeaturesAndProbWeightFunction>(*hg, &ref_exp); + ref_exp /= ref_z; + + double log_ref_z; +#if 0 + if (crf_uniform_empirical) { + log_ref_z = ref_exp.dot(feature_weights); + } else { + log_ref_z = log(ref_z); + } +#else + log_ref_z = log(ref_z); +#endif + + // rounding errors means that <0 is too strict + if ((cur_obj - log_ref_z) < kMINUS_EPSILON) { + cerr << "DIFF. ERR! log_model_z < log_ref_z: " << cur_obj << " " << log_ref_z << endl; + exit(1); + } + assert(!std::isnan(log_ref_z)); + ref_exp -= cur_model_exp; + acc_grad += ref_exp; + acc_obj += (cur_obj - log_ref_z); + } + + virtual void NotifyDecodingComplete(const SentenceMetadata& smeta) { + if (state == 3) { + ++total_complete; + } else { + } + } + + void GetGradient(SparseVector<double>* g) const { + g->clear(); + for (SparseVector<prob_t>::const_iterator it = acc_grad.begin(); it != acc_grad.end(); ++it) + g->set_value(it->first, it->second.as_float()); + } + + int total_complete; + SparseVector<prob_t> cur_model_exp; + SparseVector<prob_t> acc_grad; + double acc_obj; + double cur_obj; + int state; +}; + +#ifdef HAVE_MPI +namespace boost { namespace mpi { + template<> + struct is_commutative<std::plus<SparseVector<double> >, SparseVector<double> > + : mpl::true_ { }; +} } // end namespace boost::mpi +#endif + +bool LoadAgenda(const string& file, vector<pair<string, int> >* a) { + ReadFile rf(file); + istream& in = *rf.stream(); + string line; + while(in) { + getline(in, line); + if (!in) break; + if (line.empty()) continue; + if (line[0] == '#') continue; + int sc = 0; + if (line.size() < 3) return false; + for (int i = 0; i < line.size(); ++i) { if (line[i] == ' ') ++sc; } + if (sc != 1) { cerr << "Too many spaces in line: " << line << endl; return false; } + size_t d = line.find(" "); + pair<string, int> x; + x.first = line.substr(0,d); + x.second = atoi(line.substr(d+1).c_str()); + a->push_back(x); + if (!FileExists(x.first)) { + cerr << "Can't find file " << x.first << endl; + return false; + } + } + return true; +} + +int main(int argc, char** argv) { + cerr << "THIS SOFTWARE IS DEPRECATED YOU SHOULD USE mpi_flex_optimize\n"; +#ifdef HAVE_MPI + mpi::environment env(argc, argv); + mpi::communicator world; + const int size = world.size(); + const int rank = world.rank(); +#else + const int size = 1; + const int rank = 0; +#endif + if (size > 1) SetSilent(true); // turn off verbose decoder output + register_feature_functions(); + std::tr1::shared_ptr<MT19937> rng; + + po::variables_map conf; + if (!InitCommandLine(argc, argv, &conf)) + return 1; + + vector<pair<string, int> > agenda; + if (!LoadAgenda(conf["training_agenda"].as<string>(), &agenda)) + return 1; + if (rank == 0) + cerr << "Loaded agenda defining " << agenda.size() << " training epochs\n"; + + assert(agenda.size() > 0); + + if (1) { // hack to load the feature hash functions -- TODO this should not be in cdec.ini + const string& cur_config = agenda[0].first; + const unsigned max_iteration = agenda[0].second; + ReadFile ini_rf(cur_config); + Decoder decoder(ini_rf.stream()); + } + + // load initial weights + vector<weight_t> init_weights; + if (conf.count("input_weights")) + Weights::InitFromFile(conf["input_weights"].as<string>(), &init_weights); + + vector<int> frozen_fids; + if (conf.count("frozen_features")) { + ReadFile rf(conf["frozen_features"].as<string>()); + istream& in = *rf.stream(); + string line; + while(in) { + getline(in, line); + if (line.empty()) continue; + if (line[0] == ' ' || line[line.size() - 1] == ' ') { line = Trim(line); } + frozen_fids.push_back(FD::Convert(line)); + } + if (rank == 0) cerr << "Freezing " << frozen_fids.size() << " features.\n"; + } + + vector<string> corpus; + vector<int> ids; + ReadTrainingCorpus(conf["training_data"].as<string>(), rank, size, &corpus, &ids); + assert(corpus.size() > 0); + + std::tr1::shared_ptr<OnlineOptimizer> o; + std::tr1::shared_ptr<LearningRateSchedule> lr; + + const unsigned size_per_proc = conf["minibatch_size_per_proc"].as<unsigned>(); + if (size_per_proc > corpus.size()) { + cerr << "Minibatch size must be smaller than corpus size!\n"; + return 1; + } + + size_t total_corpus_size = 0; +#ifdef HAVE_MPI + reduce(world, corpus.size(), total_corpus_size, std::plus<size_t>(), 0); +#else + total_corpus_size = corpus.size(); +#endif + + if (rank == 0) { + cerr << "Total corpus size: " << total_corpus_size << endl; + const unsigned batch_size = size_per_proc * size; + // TODO config + lr.reset(new ExponentialDecayLearningRate(batch_size, conf["eta_0"].as<double>())); + + const string omethod = conf["optimization_method"].as<string>(); + if (omethod == "sgd") { + const double C = conf["regularization_strength"].as<double>(); + o.reset(new CumulativeL1OnlineOptimizer(lr, total_corpus_size, C, frozen_fids)); + } else { + assert(!"fail"); + } + } + if (conf.count("random_seed")) + rng.reset(new MT19937(conf["random_seed"].as<uint32_t>())); + else + rng.reset(new MT19937); + + SparseVector<double> x; + Weights::InitSparseVector(init_weights, &x); + TrainingObserver observer; + + int write_weights_every_ith = 100; // TODO configure + int titer = -1; + + unsigned timeout = 0; + if (conf.count("max_walltime")) timeout = 60 * conf["max_walltime"].as<unsigned>(); + const time_t start_time = time(NULL); + for (int ai = 0; ai < agenda.size(); ++ai) { + const string& cur_config = agenda[ai].first; + const unsigned max_iteration = agenda[ai].second; + if (rank == 0) + cerr << "STARTING TRAINING EPOCH " << (ai+1) << ". CONFIG=" << cur_config << endl; + // load cdec.ini and set up decoder + ReadFile ini_rf(cur_config); + Decoder decoder(ini_rf.stream()); + vector<weight_t>& lambdas = decoder.CurrentWeightVector(); + if (ai == 0) { lambdas.swap(init_weights); init_weights.clear(); } + + if (rank == 0) + o->ResetEpoch(); // resets the learning rate-- TODO is this good? + + int iter = -1; + bool converged = false; + while (!converged) { +#ifdef HAVE_MPI + mpi::timer timer; +#endif + x.init_vector(&lambdas); + ++iter; ++titer; + observer.Reset(); + if (rank == 0) { + converged = (iter == max_iteration); + Weights::SanityCheck(lambdas); + static int cc = 0; ++cc; if (cc > 1) { Weights::ShowLargestFeatures(lambdas); } + string fname = "weights.cur.gz"; + if (iter % write_weights_every_ith == 0) { + ostringstream o; o << "weights.epoch_" << (ai+1) << '.' << iter << ".gz"; + fname = o.str(); + } + const time_t cur_time = time(NULL); + if (timeout) { + if ((cur_time - start_time) > timeout) converged = true; + } + if (converged && ((ai+1)==agenda.size())) { fname = "weights.final.gz"; } + ostringstream vv; + double minutes = (cur_time - start_time) / 60.0; + vv << "total walltime=" << minutes << "min iter=" << titer << " (of current config iter=" << iter << ") minibatch=" << size_per_proc << " sentences/proc x " << size << " procs. num_feats=" << x.size() << '/' << FD::NumFeats() << " passes_thru_data=" << (titer * size_per_proc / static_cast<double>(corpus.size())) << " eta=" << lr->eta(titer); + const string svv = vv.str(); + cerr << svv << endl; + Weights::WriteToFile(fname, lambdas, true, &svv); + } + + for (int i = 0; i < size_per_proc; ++i) { + int ei = corpus.size() * rng->next(); + int id = ids[ei]; + decoder.SetId(id); + decoder.Decode(corpus[ei], &observer); + } + SparseVector<double> local_grad, g; + observer.GetGradient(&local_grad); +#ifdef HAVE_MPI + reduce(world, local_grad, g, std::plus<SparseVector<double> >(), 0); +#else + g.swap(local_grad); +#endif + local_grad.clear(); + if (rank == 0) { + g /= (size_per_proc * size); + o->UpdateWeights(g, FD::NumFeats(), &x); + } +#ifdef HAVE_MPI + broadcast(world, x, 0); + broadcast(world, converged, 0); + world.barrier(); + if (rank == 0) { cerr << " ELAPSED TIME THIS ITERATION=" << timer.elapsed() << endl; } +#endif + } + } + return 0; +} |