summaryrefslogtreecommitdiff
path: root/dpmert/mert_geometry.cc
diff options
context:
space:
mode:
authorChris Dyer <cdyer@cs.cmu.edu>2012-01-27 14:49:08 -0500
committerChris Dyer <cdyer@cs.cmu.edu>2012-01-27 14:49:08 -0500
commit89b662b51373f0f466d62a65d3f0a164d1d31b1c (patch)
treeefff9d60e5c128c054f848fa802e9f077cad54b1 /dpmert/mert_geometry.cc
parent3d17bf9ae1ba67cd091794839d4d5f4c393a0e2c (diff)
rename vest to dpmert (dynamic programming mert), rename variables and types to correspond to standard geometric concepts
Diffstat (limited to 'dpmert/mert_geometry.cc')
-rw-r--r--dpmert/mert_geometry.cc186
1 files changed, 186 insertions, 0 deletions
diff --git a/dpmert/mert_geometry.cc b/dpmert/mert_geometry.cc
new file mode 100644
index 00000000..81b25af9
--- /dev/null
+++ b/dpmert/mert_geometry.cc
@@ -0,0 +1,186 @@
+#include "mert_geometry.h"
+
+#include <cassert>
+#include <limits>
+
+using namespace std;
+using boost::shared_ptr;
+
+ConvexHull::ConvexHull(int i) {
+ if (i == 0) {
+ // do nothing - <>
+ } else if (i == 1) {
+ points.push_back(shared_ptr<MERTPoint>(new MERTPoint(0, 0, 0, shared_ptr<MERTPoint>(), shared_ptr<MERTPoint>())));
+ assert(this->IsMultiplicativeIdentity());
+ } else {
+ cerr << "Only can create ConvexHull semiring 0 and 1 with this constructor!\n";
+ abort();
+ }
+}
+
+const ConvexHull ConvexHullWeightFunction::operator()(const Hypergraph::Edge& e) const {
+ const double m = direction.dot(e.feature_values_);
+ const double b = origin.dot(e.feature_values_);
+ MERTPoint* point = new MERTPoint(m, b, e);
+ return ConvexHull(1, point);
+}
+
+ostream& operator<<(ostream& os, const ConvexHull& env) {
+ os << '<';
+ const vector<shared_ptr<MERTPoint> >& points = env.GetSortedSegs();
+ for (int i = 0; i < points.size(); ++i)
+ os << (i==0 ? "" : "|") << "x=" << points[i]->x << ",b=" << points[i]->b << ",m=" << points[i]->m << ",p1=" << points[i]->p1 << ",p2=" << points[i]->p2;
+ return os << '>';
+}
+
+#define ORIGINAL_MERT_IMPLEMENTATION 1
+#ifdef ORIGINAL_MERT_IMPLEMENTATION
+
+struct SlopeCompare {
+ bool operator() (const shared_ptr<MERTPoint>& a, const shared_ptr<MERTPoint>& b) const {
+ return a->m < b->m;
+ }
+};
+
+const ConvexHull& ConvexHull::operator+=(const ConvexHull& other) {
+ if (!other.is_sorted) other.Sort();
+ if (points.empty()) {
+ points = other.points;
+ return *this;
+ }
+ is_sorted = false;
+ int j = points.size();
+ points.resize(points.size() + other.points.size());
+ for (int i = 0; i < other.points.size(); ++i)
+ points[j++] = other.points[i];
+ assert(j == points.size());
+ return *this;
+}
+
+void ConvexHull::Sort() const {
+ sort(points.begin(), points.end(), SlopeCompare());
+ const int k = points.size();
+ int j = 0;
+ for (int i = 0; i < k; ++i) {
+ MERTPoint l = *points[i];
+ l.x = kMinusInfinity;
+ // cerr << "m=" << l.m << endl;
+ if (0 < j) {
+ if (points[j-1]->m == l.m) { // lines are parallel
+ if (l.b <= points[j-1]->b) continue;
+ --j;
+ }
+ while(0 < j) {
+ l.x = (l.b - points[j-1]->b) / (points[j-1]->m - l.m);
+ if (points[j-1]->x < l.x) break;
+ --j;
+ }
+ if (0 == j) l.x = kMinusInfinity;
+ }
+ *points[j++] = l;
+ }
+ points.resize(j);
+ is_sorted = true;
+}
+
+const ConvexHull& ConvexHull::operator*=(const ConvexHull& other) {
+ if (other.IsMultiplicativeIdentity()) { return *this; }
+ if (this->IsMultiplicativeIdentity()) { (*this) = other; return *this; }
+
+ if (!is_sorted) Sort();
+ if (!other.is_sorted) other.Sort();
+
+ if (this->IsEdgeEnvelope()) {
+// if (other.size() > 1)
+// cerr << *this << " (TIMES) " << other << endl;
+ shared_ptr<MERTPoint> edge_parent = points[0];
+ const double& edge_b = edge_parent->b;
+ const double& edge_m = edge_parent->m;
+ points.clear();
+ for (int i = 0; i < other.points.size(); ++i) {
+ const MERTPoint& p = *other.points[i];
+ const double m = p.m + edge_m;
+ const double b = p.b + edge_b;
+ const double& x = p.x; // x's don't change with *
+ points.push_back(shared_ptr<MERTPoint>(new MERTPoint(x, m, b, edge_parent, other.points[i])));
+ assert(points.back()->p1->edge);
+ }
+// if (other.size() > 1)
+// cerr << " = " << *this << endl;
+ } else {
+ vector<shared_ptr<MERTPoint> > new_points;
+ int this_i = 0;
+ int other_i = 0;
+ const int this_size = points.size();
+ const int other_size = other.points.size();
+ double cur_x = kMinusInfinity; // moves from left to right across the
+ // real numbers, stopping for all inter-
+ // sections
+ double this_next_val = (1 < this_size ? points[1]->x : kPlusInfinity);
+ double other_next_val = (1 < other_size ? other.points[1]->x : kPlusInfinity);
+ while (this_i < this_size && other_i < other_size) {
+ const MERTPoint& this_point = *points[this_i];
+ const MERTPoint& other_point= *other.points[other_i];
+ const double m = this_point.m + other_point.m;
+ const double b = this_point.b + other_point.b;
+
+ new_points.push_back(shared_ptr<MERTPoint>(new MERTPoint(cur_x, m, b, points[this_i], other.points[other_i])));
+ int comp = 0;
+ if (this_next_val < other_next_val) comp = -1; else
+ if (this_next_val > other_next_val) comp = 1;
+ if (0 == comp) { // the next values are equal, advance both indices
+ ++this_i;
+ ++other_i;
+ cur_x = this_next_val; // could be other_next_val (they're equal!)
+ this_next_val = (this_i+1 < this_size ? points[this_i+1]->x : kPlusInfinity);
+ other_next_val = (other_i+1 < other_size ? other.points[other_i+1]->x : kPlusInfinity);
+ } else { // advance the i with the lower x, update cur_x
+ if (-1 == comp) {
+ ++this_i;
+ cur_x = this_next_val;
+ this_next_val = (this_i+1 < this_size ? points[this_i+1]->x : kPlusInfinity);
+ } else {
+ ++other_i;
+ cur_x = other_next_val;
+ other_next_val = (other_i+1 < other_size ? other.points[other_i+1]->x : kPlusInfinity);
+ }
+ }
+ }
+ points.swap(new_points);
+ }
+ //cerr << "Multiply: result=" << (*this) << endl;
+ return *this;
+}
+
+// recursively construct translation
+void MERTPoint::ConstructTranslation(vector<WordID>* trans) const {
+ const MERTPoint* cur = this;
+ vector<vector<WordID> > ant_trans;
+ while(!cur->edge) {
+ ant_trans.resize(ant_trans.size() + 1);
+ cur->p2->ConstructTranslation(&ant_trans.back());
+ cur = cur->p1.get();
+ }
+ size_t ant_size = ant_trans.size();
+ vector<const vector<WordID>*> pants(ant_size);
+ assert(ant_size == cur->edge->tail_nodes_.size());
+ --ant_size;
+ for (int i = 0; i < pants.size(); ++i) pants[ant_size - i] = &ant_trans[i];
+ cur->edge->rule_->ESubstitute(pants, trans);
+}
+
+void MERTPoint::CollectEdgesUsed(std::vector<bool>* edges_used) const {
+ if (edge) {
+ assert(edge->id_ < edges_used->size());
+ (*edges_used)[edge->id_] = true;
+ }
+ if (p1) p1->CollectEdgesUsed(edges_used);
+ if (p2) p2->CollectEdgesUsed(edges_used);
+}
+
+#else
+
+// THIS IS THE NEW FASTER IMPLEMENTATION OF THE MERT SEMIRING OPERATIONS
+
+#endif
+