summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rwxr-xr-xbold_reranking.rb75
-rw-r--r--example/example.ini2
2 files changed, 34 insertions, 43 deletions
diff --git a/bold_reranking.rb b/bold_reranking.rb
index 3041ced..8f8cfab 100755
--- a/bold_reranking.rb
+++ b/bold_reranking.rb
@@ -29,7 +29,7 @@ class FeatureFactory
@filter_features = false
if cfg['filter_features']
@filter_features = true
- @stopwords_target = ReadFile.new(cfg['filter_features']).readlines.map{ |i| i.strip.downcase }
+ @stopwords_target = ReadFile.readlines(cfg['filter_features']).map{ |i| i.strip.downcase }
end
end
@@ -107,7 +107,7 @@ class MosesKbestEntryWithPhraseAlignment < Translation
def initialize
super
- @other_score = -1.0/0
+ @scores[:rr] = -1.0/0
end
def get_phrases
@@ -126,11 +126,8 @@ class MosesKbestEntryWithPhraseAlignment < Translation
@raw.scan(/\|-?\d+\||\|\d+-\d+\|/).map{ |i| i[1..-2] }.map{ |i| _span i }
end
- def other_score model=nil
- if model
- @other_score = model.dot(@f)
- end
- return @other_score
+ def score model
+ @scores[:rr] = model.dot(@f)
end
end
@@ -140,34 +137,27 @@ class ConstrainedSearchOracle < MosesKbestEntryWithPhraseAlignment
@id = -1
@raw = s.strip.split(' : ', 2)[1].gsub(/(\[|\])/, '|')
@s = @raw.gsub(/\s*\|\d+-\d+\||\|-?\d+\|\s*/, ' ').gsub(/\s+/, ' ')
- @score = 1.0/0
- @other_score = -1.0/0
+ @scores[:rr] = -1.0/0
end
end
-def structured_update model, hypothesis, oracle
+def structured_update model, hypothesis, oracle, learning_rate
if hypothesis.s != oracle.s
- model += oracle.f - hypothesis.f
+ model += (oracle.f - hypothesis.f) * learning_rate
return [model, 1]
end
return [model, 0]
end
-def ranking_update w, hypothesis, oracle
- if oracle.other_score <= hypothesis.other_score \
+def ranking_update w, hypothesis, oracle, learning_rate
+ if oracle.scores[:rr] <= hypothesis.scores[:rr] \
&& oracle.s != hypothesis.s
- model += oracle.f - hypothesis.f
+ model += (oracle.f - hypothesis.f) * learning_rate
return [model, 1]
end
return [model, 0]
end
-def write_model fn, w
- f = WriteFile.new fn
- f.write w.to_s+"\n"
- f.close
-end
-
def read_additional_phrase_pairs fn
f = ReadFile.new fn
add = {}
@@ -192,17 +182,19 @@ end
def main
usage if ARGV.size != 1
- cfg = read_cfg ARGV[0]
-
- sources = ReadFile.new(cfg['sources']).readlines
- oracles = ReadFile.new(cfg['oracles']).readlines
- kbest_lists = read_kbest_lists cfg['kbest_lists'], MosesKbestEntryWithPhraseAlignment
- iterations = cfg['iterate'].to_i
- output = WriteFile.new cfg['output']
- output_model = cfg['output_model']
- silent = true if cfg['silent']
- verbose = true if cfg['verbose']
- cheat = true if cfg['cheat']
+ cfg = read_config ARGV[0]
+
+ sources = ReadFile.readlines cfg['sources']
+ oracles = ReadFile.readlines cfg['oracles']
+ kbest_lists = read_kbest_lists cfg['kbest_lists'], MosesKbestEntryWithPhraseAlignment
+ learning_rate = cfg['learning_rate'].to_f
+ learning_rate = 1.0 if !learning_rate
+ iterations = cfg['iterate'].to_i
+ output = WriteFile.new cfg['output']
+ output_model = cfg['output_model']
+ silent = true if cfg['silent']
+ verbose = true if cfg['verbose']
+ cheat = true if cfg['cheat']
additional_phrase_pairs = nil
if cfg['additional_phrase_pairs']
@@ -218,7 +210,7 @@ def main
model = SparseVector.new
if cfg['init_model']
- model.from_s ReadFile.new(cfg['init_model']).read
+ model.from_s ReadFile.read cfg['init_model']
end
sz = sources.size
@@ -235,26 +227,25 @@ def main
kbest = kbest_lists[j]
kbest.each { |k|
k.f = ff.produce k, sources[j]
- k.other_score model
+ k.score model
}
- hypothesis = kbest[ kbest.map{ |k| k.other_score }.max_index ]
+ hypothesis = kbest[ kbest.map{ |k| k.scores[:rr] }.max_index ]
if !cheat
output.write "#{hypothesis.s}\n"
end
- oracle = ConstrainedSearchOracle.new
- oracle.from_s oracles[j]
+ oracle = ConstrainedSearchOracle.from_s oracles[j]
oracle.f = ff.produce oracle, sources[j]
- oracle.other_score model
+ oracle.score model
err = 0
case cfg['update']
when 'structured'
- model, err = structured_update model, hypothesis, oracle
+ model, err = structured_update model, hypothesis, oracle, learning_rate
when 'ranking'
- model, err = ranking_update model, hypothesis, oracle
+ model, err = ranking_update model, hypothesis, oracle, learning_rate
else
STDERR.write "Don't know update method '#{cfg['update']}', exiting.\n"
exit 1
@@ -262,8 +253,8 @@ def main
overall_errors += err
if cheat
- kbest.each { |k| k.other_score model }
- hypothesis = kbest[ kbest.map{ |k| k.other_score }.max_index ]
+ kbest.each { |k| k.score model }
+ hypothesis = kbest[ kbest.map{ |k| k.scores[:rr] }.max_index ]
output.write "#{hypothesis.s}\n"
end
@@ -279,7 +270,7 @@ def main
elapsed = Time.now - start
STDERR.write"#{elapsed.round 2} s, #{(elapsed/Float(sz)).round 2} s per kbest; model size: #{model.size}\n\n" if !silent
- write_model(output_model, model) if output_model
+ WriteFile.write model.to_s+"\n", output_model if output_model
output.close
end
diff --git a/example/example.ini b/example/example.ini
index 46686bd..8c08a5d 100644
--- a/example/example.ini
+++ b/example/example.ini
@@ -7,7 +7,7 @@ ff_target_ngrams = 4 # 4 fix
ff_phrase_pairs = true # true /path/to/phrase_table
#filter_features = /path/to/target/stopwords_file
binary_feature_values = true
-iterate = 3
+iterate = 1
output = -
output_model = /dev/null