diff options
-rw-r--r-- | .gitignore | 1 | ||||
-rw-r--r-- | decoder/bottom_up_parser-rs.cc | 29 | ||||
-rw-r--r-- | decoder/bottom_up_parser.cc | 43 | ||||
-rw-r--r-- | decoder/csplit.cc | 2 | ||||
-rw-r--r-- | decoder/grammar_test.cc | 4 | ||||
-rw-r--r-- | decoder/hg_io.cc | 90 | ||||
-rw-r--r-- | decoder/hg_test.cc | 6 | ||||
-rw-r--r-- | decoder/lattice.cc | 2 | ||||
-rw-r--r-- | decoder/lattice.h | 7 | ||||
-rw-r--r-- | decoder/parser_test.cc | 4 | ||||
-rw-r--r-- | tests/system_tests/lattice/gold.statistics | 7 | ||||
-rw-r--r-- | tests/system_tests/lattice/gold.stdout | 14 | ||||
-rw-r--r-- | tests/system_tests/lattice/input.txt | 1 | ||||
-rw-r--r-- | tests/system_tests/lattice/lattice.scfg | 1 | ||||
-rw-r--r-- | tests/system_tests/lattice/weights | 3 | ||||
-rw-r--r-- | utils/Makefile.am | 6 | ||||
-rw-r--r-- | utils/dedup_corpus.cc | 21 | ||||
-rw-r--r-- | utils/hash.h | 3 |
18 files changed, 165 insertions, 79 deletions
@@ -1,3 +1,4 @@ +utils/dedup_corpus klm/lm/builder/dump_counts klm/util/cat_compressed example_extff/ff_example.lo diff --git a/decoder/bottom_up_parser-rs.cc b/decoder/bottom_up_parser-rs.cc index fbde7e24..863d7e2f 100644 --- a/decoder/bottom_up_parser-rs.cc +++ b/decoder/bottom_up_parser-rs.cc @@ -35,7 +35,7 @@ class RSChart { const int j, const RuleBin* rules, const Hypergraph::TailNodeVector& tail, - const float lattice_cost); + const SparseVector<double>& lattice_feats); // returns true if a new node was added to the chart // false otherwise @@ -43,7 +43,7 @@ class RSChart { const int j, const TRulePtr& r, const Hypergraph::TailNodeVector& ant_nodes, - const float lattice_cost); + const SparseVector<double>& lattice_feats); void ApplyUnaryRules(const int i, const int j, const WordID& cat, unsigned nodeidx); void TopoSortUnaries(); @@ -69,9 +69,9 @@ WordID RSChart::kGOAL = 0; // "a type-2 is identified by a trie node, an array of back-pointers to antecedent cells, and a span" struct RSActiveItem { explicit RSActiveItem(const GrammarIter* g, int i) : - gptr_(g), ant_nodes_(), lattice_cost(0.0), i_(i) {} - void ExtendTerminal(int symbol, float src_cost) { - lattice_cost += src_cost; + gptr_(g), ant_nodes_(), lattice_feats(), i_(i) {} + void ExtendTerminal(int symbol, const SparseVector<double>& src_feats) { + lattice_feats += src_feats; if (symbol != kEPS) gptr_ = gptr_->Extend(symbol); } @@ -85,7 +85,7 @@ struct RSActiveItem { } const GrammarIter* gptr_; Hypergraph::TailNodeVector ant_nodes_; - float lattice_cost; // TODO: use SparseVector<double> to encode input features + SparseVector<double> lattice_feats; short i_; }; @@ -174,7 +174,7 @@ bool RSChart::ApplyRule(const int i, const int j, const TRulePtr& r, const Hypergraph::TailNodeVector& ant_nodes, - const float lattice_cost) { + const SparseVector<double>& lattice_feats) { Hypergraph::Edge* new_edge = forest_->AddEdge(r, ant_nodes); //cerr << i << " " << j << ": APPLYING RULE: " << r->AsString() << endl; new_edge->prev_i_ = r->prev_i; @@ -182,8 +182,7 @@ bool RSChart::ApplyRule(const int i, new_edge->i_ = i; new_edge->j_ = j; new_edge->feature_values_ = r->GetFeatureValues(); - if (lattice_cost && lc_fid_) - new_edge->feature_values_.set_value(lc_fid_, lattice_cost); + new_edge->feature_values_ += lattice_feats; Cat2NodeMap& c2n = nodemap_(i,j); const bool is_goal = (r->GetLHS() == kGOAL); const Cat2NodeMap::iterator ni = c2n.find(r->GetLHS()); @@ -211,7 +210,7 @@ void RSChart::ApplyRules(const int i, const int j, const RuleBin* rules, const Hypergraph::TailNodeVector& tail, - const float lattice_cost) { + const SparseVector<double>& lattice_feats) { const int n = rules->GetNumRules(); //cerr << i << " " << j << ": NUM RULES: " << n << endl; for (int k = 0; k < n; ++k) { @@ -219,7 +218,7 @@ void RSChart::ApplyRules(const int i, TRulePtr rule = rules->GetIthRule(k); // apply rule, and if we create a new node, apply any necessary // unary rules - if (ApplyRule(i, j, rule, tail, lattice_cost)) { + if (ApplyRule(i, j, rule, tail, lattice_feats)) { unsigned nodeidx = nodemap_(i,j)[rule->lhs_]; ApplyUnaryRules(i, j, rule->lhs_, nodeidx); } @@ -233,7 +232,7 @@ void RSChart::ApplyUnaryRules(const int i, const int j, const WordID& cat, unsig //cerr << " --MATCH\n"; WordID new_lhs = unaries_[ri]->GetLHS(); const Hypergraph::TailNodeVector ant(1, nodeidx); - if (ApplyRule(i, j, unaries_[ri], ant, 0)) { + if (ApplyRule(i, j, unaries_[ri], ant, SparseVector<double>())) { //cerr << "(" << i << "," << j << ") " << TD::Convert(-cat) << " ---> " << TD::Convert(-new_lhs) << endl; unsigned nodeidx = nodemap_(i,j)[new_lhs]; ApplyUnaryRules(i, j, new_lhs, nodeidx); @@ -245,7 +244,7 @@ void RSChart::ApplyUnaryRules(const int i, const int j, const WordID& cat, unsig void RSChart::AddToChart(const RSActiveItem& x, int i, int j) { // deal with completed rules const RuleBin* rb = x.gptr_->GetRules(); - if (rb) ApplyRules(i, j, rb, x.ant_nodes_, x.lattice_cost); + if (rb) ApplyRules(i, j, rb, x.ant_nodes_, x.lattice_feats); //cerr << "Rules applied ... looking for extensions to consume for span (" << i << "," << j << ")\n"; // continue looking for extensions of the rule to the right @@ -264,7 +263,7 @@ void RSChart::ConsumeTerminal(const RSActiveItem& x, int i, int j, int k) { if (in_edge.dist2next == check_edge_len) { //cerr << " Found word spanning (" << j << "," << k << ") in input, symbol=" << TD::Convert(in_edge.label) << endl; RSActiveItem copy = x; - copy.ExtendTerminal(in_edge.label, in_edge.cost); + copy.ExtendTerminal(in_edge.label, in_edge.features); if (copy) AddToChart(copy, i, k); } } @@ -306,7 +305,7 @@ bool RSChart::Parse() { const Hypergraph::Node& node = forest_->nodes_[dh[di]]; if (node.cat_ == goal_cat_) { Hypergraph::TailNodeVector ant(1, node.id_); - ApplyRule(0, input_.size(), goal_rule_, ant, 0); + ApplyRule(0, input_.size(), goal_rule_, ant, SparseVector<double>()); } } if (!SILENT) cerr << endl; diff --git a/decoder/bottom_up_parser.cc b/decoder/bottom_up_parser.cc index b30f1ec6..7ce8e09d 100644 --- a/decoder/bottom_up_parser.cc +++ b/decoder/bottom_up_parser.cc @@ -38,13 +38,13 @@ class PassiveChart { const int j, const RuleBin* rules, const Hypergraph::TailNodeVector& tail, - const float lattice_cost); + const SparseVector<double>& lattice_feats); void ApplyRule(const int i, const int j, const TRulePtr& r, const Hypergraph::TailNodeVector& ant_nodes, - const float lattice_cost); + const SparseVector<double>& lattice_feats); void ApplyUnaryRules(const int i, const int j); void TopoSortUnaries(); @@ -59,7 +59,6 @@ class PassiveChart { const WordID goal_cat_; // category that is being searched for at [0,n] TRulePtr goal_rule_; int goal_idx_; // index of goal node, if found - const int lc_fid_; vector<TRulePtr> unaries_; // topologically sorted list of unary rules from all grammars static WordID kGOAL; // [Goal] @@ -74,18 +73,18 @@ class ActiveChart { act_chart_(psv_chart.size(), psv_chart.size()), psv_chart_(psv_chart) {} struct ActiveItem { - ActiveItem(const GrammarIter* g, const Hypergraph::TailNodeVector& a, float lcost) : - gptr_(g), ant_nodes_(a), lattice_cost(lcost) {} + ActiveItem(const GrammarIter* g, const Hypergraph::TailNodeVector& a, const SparseVector<double>& lfeats) : + gptr_(g), ant_nodes_(a), lattice_feats(lfeats) {} explicit ActiveItem(const GrammarIter* g) : - gptr_(g), ant_nodes_(), lattice_cost(0.0) {} + gptr_(g), ant_nodes_(), lattice_feats() {} - void ExtendTerminal(int symbol, float src_cost, vector<ActiveItem>* out_cell) const { + void ExtendTerminal(int symbol, const SparseVector<double>& src_feats, vector<ActiveItem>* out_cell) const { if (symbol == kEPS) { - out_cell->push_back(ActiveItem(gptr_, ant_nodes_, lattice_cost + src_cost)); + out_cell->push_back(ActiveItem(gptr_, ant_nodes_, lattice_feats + src_feats)); } else { const GrammarIter* ni = gptr_->Extend(symbol); if (ni) - out_cell->push_back(ActiveItem(ni, ant_nodes_, lattice_cost + src_cost)); + out_cell->push_back(ActiveItem(ni, ant_nodes_, lattice_feats + src_feats)); } } void ExtendNonTerminal(const Hypergraph* hg, int node_index, vector<ActiveItem>* out_cell) const { @@ -96,12 +95,12 @@ class ActiveChart { for (unsigned i = 0; i < ant_nodes_.size(); ++i) na[i] = ant_nodes_[i]; na[ant_nodes_.size()] = node_index; - out_cell->push_back(ActiveItem(ni, na, lattice_cost)); + out_cell->push_back(ActiveItem(ni, na, lattice_feats)); } const GrammarIter* gptr_; Hypergraph::TailNodeVector ant_nodes_; - float lattice_cost; // TODO? use SparseVector<double> + SparseVector<double> lattice_feats; }; inline const vector<ActiveItem>& operator()(int i, int j) const { return act_chart_(i,j); } @@ -134,12 +133,12 @@ class ActiveChart { for (vector<LatticeArc>::const_iterator ai = out_arcs.begin(); ai != out_arcs.end(); ++ai) { const WordID& f = ai->label; - const double& c = ai->cost; + const SparseVector<double>& c = ai->features; const int& len = ai->dist2next; //cerr << "F: " << TD::Convert(f) << " dest=" << i << "," << (j+len-1) << endl; const vector<ActiveItem>& ec = act_chart_(i, j-1); //cerr << " SRC=" << i << "," << (j-1) << " [ec=" << ec.size() << "]" << endl; - //if (ec.size() > 0) { cerr << " LC=" << ec[0].lattice_cost << endl; } + //if (ec.size() > 0) { cerr << " LC=" << ec[0].lattice_feats << endl; } for (vector<ActiveItem>::const_iterator di = ec.begin(); di != ec.end(); ++di) di->ExtendTerminal(f, c, &act_chart_(i, j + len - 1)); } @@ -163,7 +162,6 @@ PassiveChart::PassiveChart(const string& goal, goal_cat_(TD::Convert(goal) * -1), goal_rule_(new TRule("[Goal] ||| [" + goal + "] ||| [1]")), goal_idx_(-1), - lc_fid_(FD::Convert("LatticeCost")), unaries_() { act_chart_.resize(grammars_.size()); for (unsigned i = 0; i < grammars_.size(); ++i) { @@ -232,7 +230,7 @@ void PassiveChart::ApplyRule(const int i, const int j, const TRulePtr& r, const Hypergraph::TailNodeVector& ant_nodes, - const float lattice_cost) { + const SparseVector<double>& lattice_feats) { Hypergraph::Edge* new_edge = forest_->AddEdge(r, ant_nodes); // cerr << i << " " << j << ": APPLYING RULE: " << r->AsString() << endl; new_edge->prev_i_ = r->prev_i; @@ -240,8 +238,7 @@ void PassiveChart::ApplyRule(const int i, new_edge->i_ = i; new_edge->j_ = j; new_edge->feature_values_ = r->GetFeatureValues(); - if (lattice_cost && lc_fid_) - new_edge->feature_values_.set_value(lc_fid_, lattice_cost); + new_edge->feature_values_ += lattice_feats; Cat2NodeMap& c2n = nodemap_(i,j); const bool is_goal = (r->GetLHS() == kGOAL); const Cat2NodeMap::iterator ni = c2n.find(r->GetLHS()); @@ -265,25 +262,25 @@ void PassiveChart::ApplyRules(const int i, const int j, const RuleBin* rules, const Hypergraph::TailNodeVector& tail, - const float lattice_cost) { + const SparseVector<double>& lattice_feats) { const int n = rules->GetNumRules(); //cerr << i << " " << j << ": NUM RULES: " << n << endl; for (int k = 0; k < n; ++k) { //cerr << i << " " << j << ": R=" << rules->GetIthRule(k)->AsString() << endl; - ApplyRule(i, j, rules->GetIthRule(k), tail, lattice_cost); + ApplyRule(i, j, rules->GetIthRule(k), tail, lattice_feats); } } void PassiveChart::ApplyUnaryRules(const int i, const int j) { const vector<int>& nodes = chart_(i,j); // reference is important! for (unsigned di = 0; di < nodes.size(); ++di) { - const WordID& cat = forest_->nodes_[nodes[di]].cat_; + const WordID cat = forest_->nodes_[nodes[di]].cat_; for (unsigned ri = 0; ri < unaries_.size(); ++ri) { //cerr << "At (" << i << "," << j << "): applying " << unaries_[ri]->AsString() << endl; if (unaries_[ri]->f()[0] == cat) { //cerr << " --MATCH\n"; const Hypergraph::TailNodeVector ant(1, nodes[di]); - ApplyRule(i, j, unaries_[ri], ant, 0); // may update nodes + ApplyRule(i, j, unaries_[ri], ant, SparseVector<double>()); // may update nodes } } } @@ -313,7 +310,7 @@ bool PassiveChart::Parse() { ai != cell.end(); ++ai) { const RuleBin* rules = (ai->gptr_->GetRules()); if (!rules) continue; - ApplyRules(i, j, rules, ai->ant_nodes_, ai->lattice_cost); + ApplyRules(i, j, rules, ai->ant_nodes_, ai->lattice_feats); } } } @@ -331,7 +328,7 @@ bool PassiveChart::Parse() { const Hypergraph::Node& node = forest_->nodes_[dh[di]]; if (node.cat_ == goal_cat_) { Hypergraph::TailNodeVector ant(1, node.id_); - ApplyRule(0, input_.size(), goal_rule_, ant, 0); + ApplyRule(0, input_.size(), goal_rule_, ant, SparseVector<double>()); } } } diff --git a/decoder/csplit.cc b/decoder/csplit.cc index 7ee4092e..7a6ed102 100644 --- a/decoder/csplit.cc +++ b/decoder/csplit.cc @@ -150,7 +150,7 @@ bool CompoundSplit::TranslateImpl(const string& input, SplitUTF8String(input, &in); smeta->SetSourceLength(in.size()); // TODO do utf8 or somethign for (int i = 0; i < in.size(); ++i) - smeta->src_lattice_.push_back(vector<LatticeArc>(1, LatticeArc(TD::Convert(in[i]), 0.0, 1))); + smeta->src_lattice_.push_back(vector<LatticeArc>(1, LatticeArc(TD::Convert(in[i]), SparseVector<double>(), 1))); smeta->ComputeInputLatticeType(); pimpl_->BuildTrellis(in, forest); forest->Reweight(weights); diff --git a/decoder/grammar_test.cc b/decoder/grammar_test.cc index 69240139..11df1e3f 100644 --- a/decoder/grammar_test.cc +++ b/decoder/grammar_test.cc @@ -47,8 +47,8 @@ BOOST_AUTO_TEST_CASE(TestTextGrammarFile) { GrammarPtr g(new TextGrammar(path + "/grammar.prune")); vector<GrammarPtr> grammars(1, g); - LatticeArc a(TD::Convert("ein"), 0.0, 1); - LatticeArc b(TD::Convert("haus"), 0.0, 1); + LatticeArc a(TD::Convert("ein"), SparseVector<double>(), 1); + LatticeArc b(TD::Convert("haus"), SparseVector<double>(), 1); Lattice lattice(2); lattice[0].push_back(a); lattice[1].push_back(b); diff --git a/decoder/hg_io.cc b/decoder/hg_io.cc index 626b2954..71f50a29 100644 --- a/decoder/hg_io.cc +++ b/decoder/hg_io.cc @@ -87,6 +87,12 @@ string HypergraphIO::AsPLF(const Hypergraph& hg, bool include_global_parentheses return os.str(); } +// TODO this should write out the PLF with the Python dictionary format +// rather than just the single "LatticeCost" feature +double PLFFeatureDictionary(const SparseVector<double>& f) { + return f.get(FD::Convert("LatticeCost")); +} + string HypergraphIO::AsPLF(const Lattice& lat, bool include_global_parentheses) { static bool first = true; if (first) { InitEscapes(); first = false; } @@ -99,7 +105,7 @@ string HypergraphIO::AsPLF(const Lattice& lat, bool include_global_parentheses) os << '('; for (int j = 0; j < arcs.size(); ++j) { os << "('" << Escape(TD::Convert(arcs[j].label)) << "'," - << arcs[j].cost << ',' << arcs[j].dist2next << "),"; + << PLFFeatureDictionary(arcs[j].features) << ',' << arcs[j].dist2next << "),"; } os << "),"; } @@ -128,7 +134,10 @@ inline void eatws(const std::string& in, int& c) { std::string getEscapedString(const std::string& in, int &c) { eatws(in,c); - if (get(in,c++) != quote) return "ERROR"; + if (get(in,c++) != quote) { + cerr << "Expected escaped string to begin with " << quote << ". Got " << get(in, c - 1) << "\n"; + abort(); + } std::string res; char cur = 0; do { @@ -146,7 +155,7 @@ float getFloat(const std::string& in, int &c) { std::string tmp; eatws(in,c); - while (c < (int)in.size() && get(in,c) != ' ' && get(in,c) != ')' && get(in,c) != ',') { + while (c < (int)in.size() && get(in,c) != ' ' && get(in,c) != ')' && get(in,c) != ',' && get(in,c) != '}') { tmp += get(in,c++); } eatws(in,c); @@ -171,7 +180,18 @@ int getInt(const std::string& in, int &c) // maximum number of nodes permitted #define MAX_NODES 100000000 -// parse ('foo', 0.23) + +void ReadPLFFeature(const std::string& in, int &c, map<string, float>& features) { + eatws(in,c); + string name = getEscapedString(in,c); + eatws(in,c); + if (get(in,c++) != ':') { cerr << "PCN/PLF parse error: expected : after feature name " << name << "\n"; abort(); } + float value = getFloat(in, c); + eatws(in,c); + features[name] = value; +} + +// parse ('foo', 0.23, 1) void ReadPLFEdge(const std::string& in, int &c, int cur_node, Hypergraph* hg) { if (get(in,c++) != '(') { cerr << "PCN/PLF parse error: expected (\n"; abort(); } vector<WordID> ewords(2, 0); @@ -180,22 +200,49 @@ void ReadPLFEdge(const std::string& in, int &c, int cur_node, Hypergraph* hg) { r->ComputeArity(); // cerr << "RULE: " << r->AsString() << endl; if (get(in,c++) != ',') { cerr << in << endl; cerr << "PCN/PLF parse error: expected , after string\n"; abort(); } + eatws(in,c); + + map<string, float> features; size_t cnNext = 1; - std::vector<float> probs; - probs.push_back(getFloat(in,c)); - while (get(in,c) == ',') { + // Read in sparse feature format + if (get(in,c) == '{') { c++; - float val = getFloat(in,c); - probs.push_back(val); - // cerr << val << endl; //REMO + eatws(in,c); + if (get(in,c) != '}') { + ReadPLFFeature(in, c, features); + } + while (get(in,c) == ',') { + c++; + if (get(in,c) == '}') { break; } + ReadPLFFeature(in, c, features); + } + if (get(in,c++) != '}') { cerr << "PCN/PLF parse error: expected } after feature dictionary\n"; abort(); } + eatws(in,c); + if (get(in, c++) != ',') { cerr << "PCN/PLF parse error: expected , after feature dictionary\n"; abort(); } + cnNext = static_cast<size_t>(getFloat(in, c)); } - //if we read more than one prob, this was a lattice, last item was column increment - if (probs.size()>1) { + // Read in dense feature format + else { + std::vector<float> probs; + probs.push_back(getFloat(in,c)); + while (get(in,c) == ',') { + c++; + float val = getFloat(in,c); + probs.push_back(val); + // cerr << val << endl; //REMO + } + if (probs.size() == 0) { cerr << "PCN/PLF parse error: missing destination state increment\n"; abort(); } + + // the last item was column increment cnNext = static_cast<size_t>(probs.back()); probs.pop_back(); - if (cnNext < 1) { cerr << cnNext << endl << "PCN/PLF parse error: bad link length at last element of cn alt block\n"; abort(); } + + for (unsigned i = 0; i < probs.size(); ++i) { + features["LatticeCost_" + to_string(i)] = probs[i]; + } } - if (get(in,c++) != ')') { cerr << "PCN/PLF parse error: expected ) at end of cn alt block\n"; abort(); } + if (get(in,c++) != ')') { cerr << "PCN/PLF parse error: expected ) at end of cn alt block. Got " << get(in, c-1) << "\n"; abort(); } + if (cnNext < 1) { cerr << cnNext << endl << "PCN/PLF parse error: bad link length at last element of cn alt block\n"; abort(); } eatws(in,c); Hypergraph::TailNodeVector tail(1, cur_node); Hypergraph::Edge* edge = hg->AddEdge(r, tail); @@ -204,15 +251,15 @@ void ReadPLFEdge(const std::string& in, int &c, int cur_node, Hypergraph* hg) { assert(head_node < MAX_NODES); // prevent malicious PLFs from using all the memory if (hg->nodes_.size() < (head_node + 1)) { hg->ResizeNodes(head_node + 1); } hg->ConnectEdgeToHeadNode(edge, &hg->nodes_[head_node]); - for (int i = 0; i < probs.size(); ++i) - edge->feature_values_.set_value(FD::Convert("Feature_" + boost::lexical_cast<string>(i)), probs[i]); + for (map<string, float>::iterator it = features.begin(); it != features.end(); ++it) { + edge->feature_values_.set_value(FD::Convert(it->first), it->second); + } } -// parse (('foo', 0.23), ('bar', 0.77)) +// parse (('foo', 0.23, 1), ('bar', 0.77, 1)) void ReadPLFNode(const std::string& in, int &c, int cur_node, int line, Hypergraph* hg) { - //cerr << "PLF READING NODE " << cur_node << endl; if (hg->nodes_.size() < (cur_node + 1)) { hg->ResizeNodes(cur_node + 1); } - if (get(in,c++) != '(') { cerr << line << ": Syntax error 1\n"; abort(); } + if (get(in,c++) != '(') { cerr << line << ": Syntax error 1 in PLF\n"; abort(); } eatws(in,c); while (1) { if (c > (int)in.size()) { break; } @@ -237,7 +284,7 @@ void HypergraphIO::ReadFromPLF(const std::string& in, Hypergraph* hg, int line) hg->clear(); int c = 0; int cur_node = 0; - if (in[c++] != '(') { cerr << line << ": Syntax error!\n"; abort(); } + if (in[c++] != '(') { cerr << line << ": Syntax error in PLF!\n"; abort(); } while (1) { if (c > (int)in.size()) { break; } if (PLF::get(in,c) == ')') { @@ -263,7 +310,6 @@ void HypergraphIO::PLFtoLattice(const string& plf, Lattice* pl) { ReadFromPLF(plf, &g, 0); const int num_nodes = g.nodes_.size() - 1; l.resize(num_nodes); - int fid0=FD::Convert("Feature_0"); for (int i = 0; i < num_nodes; ++i) { vector<LatticeArc>& alts = l[i]; const Hypergraph::Node& node = g.nodes_[i]; @@ -272,7 +318,7 @@ void HypergraphIO::PLFtoLattice(const string& plf, Lattice* pl) { for (int j = 0; j < num_alts; ++j) { const Hypergraph::Edge& edge = g.edges_[node.out_edges_[j]]; alts[j].label = edge.rule_->e_[1]; - alts[j].cost = edge.feature_values_.get(fid0); + alts[j].features = edge.feature_values_; alts[j].dist2next = edge.head_node_ - node.id_; } } diff --git a/decoder/hg_test.cc b/decoder/hg_test.cc index 366b269d..a597ad8d 100644 --- a/decoder/hg_test.cc +++ b/decoder/hg_test.cc @@ -214,8 +214,8 @@ BOOST_AUTO_TEST_CASE(TestIntersect) { BOOST_CHECK_EQUAL(4, best); Lattice target(2); - target[0].push_back(LatticeArc(TD::Convert("a"), 0.0, 1)); - target[1].push_back(LatticeArc(TD::Convert("b"), 0.0, 1)); + target[0].push_back(LatticeArc(TD::Convert("a"), SparseVector<double>(), 1)); + target[1].push_back(LatticeArc(TD::Convert("b"), SparseVector<double>(), 1)); HG::Intersect(target, &hg); hg.PrintGraphviz(); } @@ -256,7 +256,7 @@ BOOST_AUTO_TEST_CASE(PLF) { string inplf = "((('haupt',-2.06655,1),('hauptgrund',-5.71033,2),),(('grund',-1.78709,1),),(('für\\'',0.1,1),),)"; HypergraphIO::ReadFromPLF(inplf, &hg); SparseVector<double> wts; - wts.set_value(FD::Convert("Feature_0"), 1.0); + wts.set_value(FD::Convert("LatticeCost_0"), 1.0); hg.Reweight(wts); hg.PrintGraphviz(); string outplf = HypergraphIO::AsPLF(hg); diff --git a/decoder/lattice.cc b/decoder/lattice.cc index 1f97048d..2740ce63 100644 --- a/decoder/lattice.cc +++ b/decoder/lattice.cc @@ -49,7 +49,7 @@ void LatticeTools::ConvertTextToLattice(const string& text, Lattice* pl) { l.clear(); l.resize(ids.size()); for (int i = 0; i < l.size(); ++i) - l[i].push_back(LatticeArc(ids[i], 0.0, 1)); + l[i].push_back(LatticeArc(ids[i], SparseVector<double>(), 1)); } void LatticeTools::ConvertTextOrPLF(const string& text_or_plf, Lattice* pl) { diff --git a/decoder/lattice.h b/decoder/lattice.h index 1258d3f5..469615b5 100644 --- a/decoder/lattice.h +++ b/decoder/lattice.h @@ -3,6 +3,7 @@ #include <string> #include <vector> +#include "sparse_vector.h" #include "wordid.h" #include "array2d.h" @@ -15,10 +16,10 @@ struct LatticeTools { struct LatticeArc { WordID label; - double cost; + SparseVector<double> features; int dist2next; - LatticeArc() : label(), cost(), dist2next() {} - LatticeArc(WordID w, double c, int i) : label(w), cost(c), dist2next(i) {} + LatticeArc() : label(), features(), dist2next() {} + LatticeArc(WordID w, const SparseVector<double>& f, int i) : label(w), features(f), dist2next(i) {} }; class Lattice : public std::vector<std::vector<LatticeArc> > { diff --git a/decoder/parser_test.cc b/decoder/parser_test.cc index e2916e44..37c45cc8 100644 --- a/decoder/parser_test.cc +++ b/decoder/parser_test.cc @@ -10,8 +10,8 @@ using namespace std; BOOST_AUTO_TEST_CASE(Parse) { - LatticeArc a(TD::Convert("ein"), 0.0, 1); - LatticeArc b(TD::Convert("haus"), 0.0, 1); + LatticeArc a(TD::Convert("ein"), SparseVector<double>(), 1); + LatticeArc b(TD::Convert("haus"), SparseVector<double>(), 1); Lattice lattice(2); lattice[0].push_back(a); lattice[1].push_back(b); diff --git a/tests/system_tests/lattice/gold.statistics b/tests/system_tests/lattice/gold.statistics index 302ddf14..4ce55900 100644 --- a/tests/system_tests/lattice/gold.statistics +++ b/tests/system_tests/lattice/gold.statistics @@ -5,3 +5,10 @@ +lm_edges 10 +lm_paths 5 +lm_trans ab +-lm_nodes 3 +-lm_edges 6 +-lm_paths 4 ++lm_nodes 3 ++lm_edges 6 ++lm_paths 4 ++lm_trans d' diff --git a/tests/system_tests/lattice/gold.stdout b/tests/system_tests/lattice/gold.stdout index 1adb51f1..dd2a2943 100644 --- a/tests/system_tests/lattice/gold.stdout +++ b/tests/system_tests/lattice/gold.stdout @@ -1,5 +1,9 @@ -0 ||| ab ||| SourceWordPenalty=-0.434294 WordPenalty=-0.434294 Cost=0.1 LatticeCost=0.125 ||| -1.09359 -0 ||| cb ||| SourceWordPenalty=-0.868589 WordPenalty=-0.434294 Cost=0.3 LatticeCost=2.25 ||| -3.85288 -0 ||| a_b ||| SourceWordPenalty=-0.868589 WordPenalty=-0.434294 Cost=0.2 LatticeCost=2.5 ||| -4.00288 -0 ||| a' b ||| Glue=1 SourceWordPenalty=-0.868589 WordPenalty=-0.868589 Cost=0.3 LatticeCost=2.5 ||| -4.53718 -0 ||| a b ||| Glue=1 SourceWordPenalty=-0.868589 WordPenalty=-0.868589 Cost=0.3 LatticeCost=2.5 ||| -4.53718 +0 ||| ab ||| SourceWordPenalty=-0.434294 WordPenalty=-0.434294 Cost=0.1 LatticeCost_0=0.125 ||| -1.09359 +0 ||| cb ||| SourceWordPenalty=-0.868589 WordPenalty=-0.434294 Cost=0.3 LatticeCost_0=2.25 ||| -3.85288 +0 ||| a_b ||| SourceWordPenalty=-0.868589 WordPenalty=-0.434294 Cost=0.2 LatticeCost_0=2.5 ||| -4.00288 +0 ||| a' b ||| Glue=1 SourceWordPenalty=-0.868589 WordPenalty=-0.868589 Cost=0.3 LatticeCost_0=2.5 ||| -4.53718 +0 ||| a b ||| Glue=1 SourceWordPenalty=-0.868589 WordPenalty=-0.868589 Cost=0.3 LatticeCost_0=2.5 ||| -4.53718 +1 ||| d' ||| SourceWordPenalty=-0.434294 WordPenalty=-0.434294 Cost0=-0.1 LatticeCost_0=0.1 UsesDPrime=1 ||| 999.031 +1 ||| b ||| SourceWordPenalty=-0.434294 WordPenalty=-0.434294 Cost=0.2 ||| -1.06859 +1 ||| a' ||| SourceWordPenalty=-0.434294 WordPenalty=-0.434294 Cost=0.1 LatticeCost_0=0.5 ||| -1.46859 +1 ||| a ||| SourceWordPenalty=-0.434294 WordPenalty=-0.434294 Cost=0.1 LatticeCost_0=0.5 ||| -1.46859 diff --git a/tests/system_tests/lattice/input.txt b/tests/system_tests/lattice/input.txt index e0cd1b57..17bfd47c 100644 --- a/tests/system_tests/lattice/input.txt +++ b/tests/system_tests/lattice/input.txt @@ -1 +1,2 @@ ((('A',0.5,1),('C',0.25,1),('AB',0.125,2),),(('B',2,1),),) +((('A',0.5,1),('D\'',{'LatticeCost_0':0.1, 'UsesDPrime':1.0,},1),('B', 1)),) diff --git a/tests/system_tests/lattice/lattice.scfg b/tests/system_tests/lattice/lattice.scfg index 87a72383..04fe0cf0 100644 --- a/tests/system_tests/lattice/lattice.scfg +++ b/tests/system_tests/lattice/lattice.scfg @@ -4,3 +4,4 @@ [X] ||| AB ||| ab ||| Cost=0.1 [X] ||| C B ||| cb ||| Cost=0.3 [X] ||| A B ||| a_b ||| Cost=0.2 +[X] ||| D' ||| d' ||| Cost0=-0.1 diff --git a/tests/system_tests/lattice/weights b/tests/system_tests/lattice/weights index cb59b27b..7e7d0fa8 100644 --- a/tests/system_tests/lattice/weights +++ b/tests/system_tests/lattice/weights @@ -2,4 +2,5 @@ WordPenalty 1 SourceWordPenalty 1 Glue 0 Cost -1 -LatticeCost -1 +LatticeCost_0 -1 +UsesDPrime 1000 diff --git a/utils/Makefile.am b/utils/Makefile.am index dd74ddc0..c858ac7e 100644 --- a/utils/Makefile.am +++ b/utils/Makefile.am @@ -1,4 +1,4 @@ -bin_PROGRAMS = reconstruct_weights atools +bin_PROGRAMS = reconstruct_weights atools dedup_corpus noinst_PROGRAMS = \ ts \ @@ -98,6 +98,10 @@ atools_SOURCES = atools.cc atools_LDADD = libutils.a atools_LDFLAGS = $(STATIC_FLAGS) +dedup_corpus_SOURCES = dedup_corpus.cc +dedup_corpus_LDADD = libutils.a +dedup_corpus_LDFLAGS = $(STATIC_FLAGS) + phmt_SOURCES = phmt.cc phmt_LDADD = libutils.a $(BOOST_UNIT_TEST_FRAMEWORK_LDFLAGS) $(BOOST_UNIT_TEST_FRAMEWORK_LIBS) ts_SOURCES = ts.cc diff --git a/utils/dedup_corpus.cc b/utils/dedup_corpus.cc new file mode 100644 index 00000000..818a6ffa --- /dev/null +++ b/utils/dedup_corpus.cc @@ -0,0 +1,21 @@ +#include <iostream> +#include "hash.h" + +using namespace std; + +#define INITIAL_SIZE 20000000 + +int main(int argc, char **argv) { + if (argc != 1) { + cerr << "Usage: " << argv[0] << " < file.txt\n"; + return 1; + } + SPARSE_HASH_SET<uint64_t> seen(INITIAL_SIZE); + string line; + while(getline(cin, line)) { + uint64_t h = cdec::MurmurHash3_64(&line[0], line.size(), 17); + if (seen.insert(h).second) + cout << line << '\n'; + } +} + diff --git a/utils/hash.h b/utils/hash.h index 24d2b6ad..7de4db6d 100644 --- a/utils/hash.h +++ b/utils/hash.h @@ -13,7 +13,9 @@ # include <sparsehash/dense_hash_map> # include <sparsehash/dense_hash_set> # include <sparsehash/sparse_hash_map> +# include <sparsehash/sparse_hash_set> # define SPARSE_HASH_MAP google::sparse_hash_map +# define SPARSE_HASH_SET google::sparse_hash_set # define HASH_MAP google::dense_hash_map # define HASH_SET google::dense_hash_set # define HASH_MAP_DELETED(h,deleted) do { (h).set_deleted_key(deleted); } while(0) @@ -29,6 +31,7 @@ namespace std { using std::tr1::unordered_map; using std::tr1::unordered_set; } #endif # define SPARSE_HASH_MAP std::unordered_map +# define SPARSE_HASH_SET std::unordered_set # define HASH_MAP std::unordered_map # define HASH_SET std::unordered_set # define HASH_MAP_DELETED(h,deleted) |