summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorChris Dyer <redpony@gmail.com>2014-07-01 13:49:47 -0400
committerChris Dyer <redpony@gmail.com>2014-07-01 13:49:47 -0400
commitf3255478edc9b04fd13d4e74a2fbf37aeb981202 (patch)
treec7fb4875c55abcc411c2557255e5203c3b7aa15c
parent9a9abc5f6e9b3f26daf5f276434c1fd7f0c83da2 (diff)
track spans in t2s translation
-rw-r--r--configure.ac7
-rw-r--r--decoder/tree2string_translator.cc4
-rw-r--r--decoder/tree_fragment.cc17
-rw-r--r--decoder/tree_fragment.h7
-rw-r--r--training/dtrain/Makefile.am2
5 files changed, 31 insertions, 6 deletions
diff --git a/configure.ac b/configure.ac
index 6b128768..d7ced0ea 100644
--- a/configure.ac
+++ b/configure.ac
@@ -1,5 +1,5 @@
AC_CONFIG_MACRO_DIR([m4])
-AC_INIT([cdec],[2014-01-28])
+AC_INIT([cdec],[2014-06-15])
AC_CONFIG_SRCDIR([decoder/cdec.cc])
AM_INIT_AUTOMAKE
AC_CONFIG_HEADERS(config.h)
@@ -19,6 +19,7 @@ BOOST_REQUIRE([1.44])
BOOST_FILESYSTEM
BOOST_PROGRAM_OPTIONS
BOOST_SYSTEM
+BOOST_REGEX
BOOST_SERIALIZATION
BOOST_TEST
BOOST_THREADS
@@ -178,9 +179,9 @@ fi
#BOOST_THREADS
CPPFLAGS="$CPPFLAGS $BOOST_CPPFLAGS"
-LDFLAGS="$LDFLAGS $BOOST_PROGRAM_OPTIONS_LDFLAGS $BOOST_SERIALIZATION_LDFLAGS $BOOST_SYSTEM_LDFLAGS $BOOST_FILESYSTEM_LDFLAGS"
+LDFLAGS="$LDFLAGS $BOOST_PROGRAM_OPTIONS_LDFLAGS $BOOST_REGEX_LDFLAGS $BOOST_SERIALIZATION_LDFLAGS $BOOST_SYSTEM_LDFLAGS $BOOST_FILESYSTEM_LDFLAGS"
# $BOOST_THREAD_LDFLAGS"
-LIBS="$LIBS $BOOST_PROGRAM_OPTIONS_LIBS $BOOST_SERIALIZATION_LIBS $BOOST_SYSTEM_LIBS $BOOST_FILESYSTEM_LIBS $ZLIBS"
+LIBS="$LIBS $BOOST_PROGRAM_OPTIONS_LIBS $BOOST_REGEX_LIBS $BOOST_SERIALIZATION_LIBS $BOOST_SYSTEM_LIBS $BOOST_FILESYSTEM_LIBS $ZLIBS"
# $BOOST_THREAD_LIBS"
AC_CHECK_HEADER(google/dense_hash_map,
diff --git a/decoder/tree2string_translator.cc b/decoder/tree2string_translator.cc
index c9c91a37..adc8dc89 100644
--- a/decoder/tree2string_translator.cc
+++ b/decoder/tree2string_translator.cc
@@ -332,7 +332,9 @@ struct Tree2StringTranslatorImpl {
assert(tail.size() == r->Arity());
HG::Edge* new_edge = hg.AddEdge(r, tail);
new_edge->feature_values_ = r->GetFeatureValues();
- // TODO: set i and j
+ auto& inspan = input_tree.nodes[s.task.input_node_idx].span;
+ new_edge->i_ = inspan.first;
+ new_edge->j_ = inspan.second;
hg.ConnectEdgeToHeadNode(new_edge, &hg.nodes_[node_id]);
}
for (const auto& n : s.future_work) {
diff --git a/decoder/tree_fragment.cc b/decoder/tree_fragment.cc
index aad0b2c4..42f7793a 100644
--- a/decoder/tree_fragment.cc
+++ b/decoder/tree_fragment.cc
@@ -28,12 +28,14 @@ TreeFragment::TreeFragment(const StringPiece& tree, bool allow_frontier_sites) {
unsigned cp = 0, symp = 0, np = 0;
ParseRec(tree, allow_frontier_sites, cp, symp, np, &cp, &symp, &np);
root = nodes.back().lhs;
+ if (!allow_frontier_sites) SetupSpansRec(open - 1, 0);
//cerr << "ROOT: " << TD::Convert(root & ALL_MASK) << endl;
//DebugRec(open - 1, &cerr); cerr << "\n";
}
void TreeFragment::DebugRec(unsigned cur, ostream* out) const {
*out << '(' << TD::Convert(nodes[cur].lhs & ALL_MASK);
+ // *out << "_{" << nodes[cur].span.first << ',' << nodes[cur].span.second << '}';
for (auto& x : nodes[cur].rhs) {
*out << ' ';
if (IsFrontier(x)) {
@@ -47,6 +49,21 @@ void TreeFragment::DebugRec(unsigned cur, ostream* out) const {
*out << ')';
}
+// returns left + the number of terminals rooted at NT cur,
+int TreeFragment::SetupSpansRec(unsigned cur, int left) {
+ int right = left;
+ for (auto& x : nodes[cur].rhs) {
+ if (IsRHS(x)) {
+ right = SetupSpansRec(x & ALL_MASK, right);
+ } else {
+ ++right;
+ }
+ }
+ nodes[cur].span.first = left;
+ nodes[cur].span.second = right;
+ return right;
+}
+
// cp is the character index in the tree
// np keeps track of the nodes (nonterminals) that have been built
// symp keeps track of the terminal symbols that have been built
diff --git a/decoder/tree_fragment.h b/decoder/tree_fragment.h
index 8bb7251a..6b4842ee 100644
--- a/decoder/tree_fragment.h
+++ b/decoder/tree_fragment.h
@@ -43,9 +43,10 @@ inline bool IsTerminal(unsigned x) {
struct TreeFragmentProduction {
TreeFragmentProduction() {}
- TreeFragmentProduction(int nttype, const std::vector<unsigned>& r) : lhs(nttype), rhs(r) {}
+ TreeFragmentProduction(int nttype, const std::vector<unsigned>& r) : lhs(nttype), rhs(r), span(std::make_pair<short,short>(-1,-1)) {}
unsigned lhs;
std::vector<unsigned> rhs;
+ std::pair<short, short> span; // the span of the node (in input, or not set for rules)
};
// this data structure represents a tree or forest
@@ -76,6 +77,10 @@ class TreeFragment {
// np keeps track of the nodes (nonterminals) that have been built
// symp keeps track of the terminal symbols that have been built
void ParseRec(const StringPiece& tree, bool afs, unsigned cp, unsigned symp, unsigned np, unsigned* pcp, unsigned* psymp, unsigned* pnp);
+
+ // used by constructor to set up span indices for logging/alignment purposes
+ int SetupSpansRec(unsigned cur, int left);
+
public:
unsigned root;
unsigned char frontier_sites;
diff --git a/training/dtrain/Makefile.am b/training/dtrain/Makefile.am
index ecb6c128..844c790d 100644
--- a/training/dtrain/Makefile.am
+++ b/training/dtrain/Makefile.am
@@ -1,7 +1,7 @@
bin_PROGRAMS = dtrain
dtrain_SOURCES = dtrain.cc score.cc dtrain.h kbestget.h ksampler.h pairsampling.h score.h
-dtrain_LDADD = ../../decoder/libcdec.a ../../klm/search/libksearch.a ../../mteval/libmteval.a ../../utils/libutils.a ../../klm/lm/libklm.a ../../klm/util/libklm_util.a ../../klm/util/double-conversion/libklm_util_double.a -lboost_regex
+dtrain_LDADD = ../../decoder/libcdec.a ../../klm/search/libksearch.a ../../mteval/libmteval.a ../../utils/libutils.a ../../klm/lm/libklm.a ../../klm/util/libklm_util.a ../../klm/util/double-conversion/libklm_util_double.a
AM_CPPFLAGS = -W -Wall -Wno-sign-compare -I$(top_srcdir)/utils -I$(top_srcdir)/decoder -I$(top_srcdir)/mteval