diff options
author | Wu, Ke <wuke@cs.umd.edu> | 2014-10-07 18:44:05 -0400 |
---|---|---|
committer | Wu, Ke <wuke@cs.umd.edu> | 2014-10-07 18:44:05 -0400 |
commit | 8c26c195213805face566a6407597ba2a871a122 (patch) | |
tree | 378301ff345bf465f407f1447ad5fe126b3cd47c /utils/argument_reorder_model.cc | |
parent | 6c7bf8cf49db88ca47e5b08aa449032995736854 (diff) |
Move synutils under utils
Diffstat (limited to 'utils/argument_reorder_model.cc')
-rw-r--r-- | utils/argument_reorder_model.cc | 323 |
1 files changed, 323 insertions, 0 deletions
diff --git a/utils/argument_reorder_model.cc b/utils/argument_reorder_model.cc new file mode 100644 index 00000000..58886251 --- /dev/null +++ b/utils/argument_reorder_model.cc @@ -0,0 +1,323 @@ +/* + * argument_reorder_model.cc + * + * Created on: Dec 15, 2013 + * Author: lijunhui + */ + +#include <boost/program_options.hpp> +#include <fstream> + +#include "argument_reorder_model.h" +#include "synutils.h" +#include "tsuruoka_maxent.h" + +inline void fnPreparingTrainingdata(const char* pszFName, int iCutoff, + const char* pszNewFName) { + SFReader* pFReader = new STxtFileReader(pszFName); + char* pszLine = new char[100001]; + int iLen; + Map hashPredicate; + while (pFReader->fnReadNextLine(pszLine, &iLen)) { + if (iLen == 0) continue; + + vector<string> vecTerms; + SplitOnWhitespace(string(pszLine), &vecTerms); + + for (size_t i = 0; i < vecTerms.size() - 1; i++) { + Iterator iter = hashPredicate.find(vecTerms[i]); + if (iter == hashPredicate.end()) { + hashPredicate[vecTerms[i]] = 1; + + } else { + iter->second++; + } + } + } + delete pFReader; + + pFReader = new STxtFileReader(pszFName); + FILE* fpOut = fopen(pszNewFName, "w"); + while (pFReader->fnReadNextLine(pszLine, &iLen)) { + if (iLen == 0) continue; + + vector<string> vecTerms; + SplitOnWhitespace(string(pszLine), &vecTerms); + ostringstream ostr; + for (size_t i = 0; i < vecTerms.size() - 1; i++) { + Iterator iter = hashPredicate.find(vecTerms[i]); + assert(iter != hashPredicate.end()); + if (iter->second >= iCutoff) { + ostr << vecTerms[i] << " "; + } + } + if (ostr.str().length() > 0) { + ostr << vecTerms[vecTerms.size() - 1]; + fprintf(fpOut, "%s\n", ostr.str().c_str()); + } + } + fclose(fpOut); + delete pFReader; + + delete[] pszLine; +} + +struct SArgumentReorderTrainer { + SArgumentReorderTrainer( + const char* pszSRLFname, // source-side srl tree file name + const char* pszAlignFname, // alignment filename + const char* pszSourceFname, // source file name + const char* pszTargetFname, // target file name + const char* pszTopPredicateFname, // target file name + const char* pszInstanceFname, // training instance file name + const char* pszModelFname, // classifier model file name + int iCutoff) { + fnGenerateInstanceFiles(pszSRLFname, pszAlignFname, pszSourceFname, + pszTargetFname, pszTopPredicateFname, + pszInstanceFname); + + string strInstanceFname, strModelFname; + strInstanceFname = string(pszInstanceFname) + string(".left"); + strModelFname = string(pszModelFname) + string(".left"); + fnTraining(strInstanceFname.c_str(), strModelFname.c_str(), iCutoff); + strInstanceFname = string(pszInstanceFname) + string(".right"); + strModelFname = string(pszModelFname) + string(".right"); + fnTraining(strInstanceFname.c_str(), strModelFname.c_str(), iCutoff); + } + + ~SArgumentReorderTrainer() {} + + private: + void fnTraining(const char* pszInstanceFname, const char* pszModelFname, + int iCutoff) { + char* pszNewInstanceFName = new char[strlen(pszInstanceFname) + 50]; + if (iCutoff > 0) { + sprintf(pszNewInstanceFName, "%s.tmp", pszInstanceFname); + fnPreparingTrainingdata(pszInstanceFname, iCutoff, pszNewInstanceFName); + } else { + strcpy(pszNewInstanceFName, pszInstanceFname); + } + + Tsuruoka_Maxent* pMaxent = new Tsuruoka_Maxent(NULL); + pMaxent->fnTrain(pszNewInstanceFName, "l1", pszModelFname, 300); + delete pMaxent; + + if (strcmp(pszNewInstanceFName, pszInstanceFname) != 0) { + sprintf(pszNewInstanceFName, "rm %s.tmp", pszInstanceFname); + system(pszNewInstanceFName); + } + delete[] pszNewInstanceFName; + } + + void fnGenerateInstanceFiles( + const char* pszSRLFname, // source-side flattened parse tree file name + const char* pszAlignFname, // alignment filename + const char* pszSourceFname, // source file name + const char* pszTargetFname, // target file name + const char* pszTopPredicateFname, // top predicate file name (we only + // consider predicates with 100+ + // occurrences + const char* pszInstanceFname // training instance file name + ) { + SAlignmentReader* pAlignReader = new SAlignmentReader(pszAlignFname); + SSrlSentenceReader* pSRLReader = new SSrlSentenceReader(pszSRLFname); + STxtFileReader* pTxtSReader = new STxtFileReader(pszSourceFname); + STxtFileReader* pTxtTReader = new STxtFileReader(pszTargetFname); + + Map* pMapPredicate; + if (pszTopPredicateFname != NULL) + pMapPredicate = fnLoadTopPredicates(pszTopPredicateFname); + else + pMapPredicate = NULL; + + char* pszLine = new char[50001]; + + FILE* fpLeftOut, *fpRightOut; + sprintf(pszLine, "%s.left", pszInstanceFname); + fpLeftOut = fopen(pszLine, "w"); + sprintf(pszLine, "%s.right", pszInstanceFname); + fpRightOut = fopen(pszLine, "w"); + + // read sentence by sentence + SAlignment* pAlign; + SSrlSentence* pSRL; + SParsedTree* pTree; + int iSentNum = 0; + while ((pAlign = pAlignReader->fnReadNextAlignment()) != NULL) { + pSRL = pSRLReader->fnReadNextSrlSentence(); + assert(pSRL != NULL); + pTree = pSRL->m_pTree; + assert(pTxtSReader->fnReadNextLine(pszLine, NULL)); + vector<string> vecSTerms; + SplitOnWhitespace(string(pszLine), &vecSTerms); + assert(pTxtTReader->fnReadNextLine(pszLine, NULL)); + vector<string> vecTTerms; + SplitOnWhitespace(string(pszLine), &vecTTerms); + // vecTPOSTerms.size() == 0, given the case when an english sentence fails + // parsing + + if (pTree != NULL) { + for (size_t i = 0; i < pSRL->m_vecPred.size(); i++) { + SPredicate* pPred = pSRL->m_vecPred[i]; + if (strcmp(pTree->m_vecTerminals[pPred->m_iPosition] + ->m_ptParent->m_pszTerm, + "VA") == 0) + continue; + string strPred = + string(pTree->m_vecTerminals[pPred->m_iPosition]->m_pszTerm); + if (pMapPredicate != NULL) { + Map::iterator iter_map = pMapPredicate->find(strPred); + if (pMapPredicate != NULL && iter_map == pMapPredicate->end()) + continue; + } + + SPredicateItem* pPredItem = new SPredicateItem(pTree, pPred); + + vector<string> vecStrBlock; + for (size_t j = 0; j < pPredItem->vec_items_.size(); j++) { + SSRLItem* pItem1 = pPredItem->vec_items_[j]; + vecStrBlock.push_back(SArgumentReorderModel::fnGetBlockOutcome( + pItem1->tree_item_->m_iBegin, pItem1->tree_item_->m_iEnd, + pAlign)); + } + + vector<string> vecStrLeftReorderType; + vector<string> vecStrRightReorderType; + SArgumentReorderModel::fnGetReorderType( + pPredItem, pAlign, vecStrLeftReorderType, vecStrRightReorderType); + for (int j = 1; j < pPredItem->vec_items_.size(); j++) { + string strLeftOutcome, strRightOutcome; + strLeftOutcome = vecStrLeftReorderType[j - 1]; + strRightOutcome = vecStrRightReorderType[j - 1]; + ostringstream ostr; + SArgumentReorderModel::fnGenerateFeature(pTree, pPred, pPredItem, j, + vecStrBlock[j - 1], + vecStrBlock[j], ostr); + + // fprintf(stderr, "%s %s\n", ostr.str().c_str(), + // strOutcome.c_str()); + // fprintf(fpOut, "sentid=%d %s %s\n", iSentNum, ostr.str().c_str(), + // strOutcome.c_str()); + fprintf(fpLeftOut, "%s %s\n", ostr.str().c_str(), + strLeftOutcome.c_str()); + fprintf(fpRightOut, "%s %s\n", ostr.str().c_str(), + strRightOutcome.c_str()); + } + } + } + delete pSRL; + + delete pAlign; + iSentNum++; + + if (iSentNum % 100000 == 0) fprintf(stderr, "#%d\n", iSentNum); + } + delete[] pszLine; + + fclose(fpLeftOut); + fclose(fpRightOut); + + delete pAlignReader; + delete pSRLReader; + delete pTxtSReader; + delete pTxtTReader; + } + + Map* fnLoadTopPredicates(const char* pszTopPredicateFname) { + if (pszTopPredicateFname == NULL) return NULL; + + Map* pMapPredicate = new Map(); + STxtFileReader* pReader = new STxtFileReader(pszTopPredicateFname); + char* pszLine = new char[50001]; + int iNumCount = 0; + while (pReader->fnReadNextLine(pszLine, NULL)) { + if (pszLine[0] == '#') continue; + char* p = strchr(pszLine, ' '); + assert(p != NULL); + p[0] = '\0'; + p++; + int iCount = atoi(p); + if (iCount < 100) break; + (*pMapPredicate)[string(pszLine)] = iNumCount++; + } + delete pReader; + delete[] pszLine; + return pMapPredicate; + } +}; + +namespace po = boost::program_options; + +inline void print_options(std::ostream& out, + po::options_description const& opts) { + typedef std::vector<boost::shared_ptr<po::option_description> > Ds; + Ds const& ds = opts.options(); + out << '"'; + for (unsigned i = 0; i < ds.size(); ++i) { + if (i) out << ' '; + out << "--" << ds[i]->long_name(); + } + out << '"\n'; +} +inline string str(char const* name, po::variables_map const& conf) { + return conf[name].as<string>(); +} + +//--srl_file /scratch0/mt_exp/gale-align/gale-align.nw.srl.cn --align_file +///scratch0/mt_exp/gale-align/gale-align.nw.al --source_file +///scratch0/mt_exp/gale-align/gale-align.nw.cn --target_file +///scratch0/mt_exp/gale-align/gale-align.nw.en --instance_file +///scratch0/mt_exp/gale-align/gale-align.nw.argreorder.instance --model_prefix +///scratch0/mt_exp/gale-align/gale-align.nw.argreorder.model --feature_cutoff 2 +//--srl_file /scratch0/mt_exp/gale-ctb/gale-ctb.srl.cn --align_file +///scratch0/mt_exp/gale-ctb/gale-ctb.align --source_file +///scratch0/mt_exp/gale-ctb/gale-ctb.cn --target_file +///scratch0/mt_exp/gale-ctb/gale-ctb.en0 --instance_file +///scratch0/mt_exp/gale-ctb/gale-ctb.argreorder.instance --model_prefix +///scratch0/mt_exp/gale-ctb/gale-ctb.argreorder.model --feature_cutoff 2 +int main(int argc, char** argv) { + + po::options_description opts("Configuration options"); + opts.add_options()("srl_file", po::value<string>(), "srl file path (input)")( + "align_file", po::value<string>(), "Alignment file path (input)")( + "source_file", po::value<string>(), "Source text file path (input)")( + "target_file", po::value<string>(), "Target text file path (input)")( + "instance_file", po::value<string>(), "Instance file path (output)")( + "model_prefix", po::value<string>(), + "Model file path prefix (output): three files will be generated")( + "feature_cutoff", po::value<int>()->default_value(100), + "Feature cutoff threshold")("help", "produce help message"); + + po::variables_map vm; + if (argc) { + po::store(po::parse_command_line(argc, argv, opts), vm); + po::notify(vm); + } + + if (vm.count("help")) { + print_options(cout, opts); + return 1; + } + + if (!vm.count("srl_file") || !vm.count("align_file") || + !vm.count("source_file") || !vm.count("target_file") || + !vm.count("instance_file") || !vm.count("model_prefix")) { + print_options(cout, opts); + if (!vm.count("parse_file")) cout << "--parse_file NOT FOUND\n"; + if (!vm.count("align_file")) cout << "--align_file NOT FOUND\n"; + if (!vm.count("source_file")) cout << "--source_file NOT FOUND\n"; + if (!vm.count("target_file")) cout << "--target_file NOT FOUND\n"; + if (!vm.count("instance_file")) cout << "--instance_file NOT FOUND\n"; + if (!vm.count("model_prefix")) cout << "--model_prefix NOT FOUND\n"; + exit(0); + } + + SArgumentReorderTrainer* pTrainer = new SArgumentReorderTrainer( + str("srl_file", vm).c_str(), str("align_file", vm).c_str(), + str("source_file", vm).c_str(), str("target_file", vm).c_str(), NULL, + str("instance_file", vm).c_str(), str("model_prefix", vm).c_str(), + vm["feature_cutoff"].as<int>()); + delete pTrainer; + + return 1; +} |