diff options
Diffstat (limited to 'utils/synutils/argument_reorder_model.cc')
-rw-r--r-- | utils/synutils/argument_reorder_model.cc | 311 |
1 files changed, 311 insertions, 0 deletions
diff --git a/utils/synutils/argument_reorder_model.cc b/utils/synutils/argument_reorder_model.cc new file mode 100644 index 00000000..e30ee971 --- /dev/null +++ b/utils/synutils/argument_reorder_model.cc @@ -0,0 +1,311 @@ +/* + * argument_reorder_model.cc + * + * Created on: Dec 15, 2013 + * Author: lijunhui + */ + +#include <boost/program_options.hpp> +#include <fstream> + +#include "argument_reorder_model.h" +#include "utility.h" +#include "tsuruoka_maxent.h" + +inline void fnPreparingTrainingdata(const char* pszFName, int iCutoff, const char* pszNewFName) { + SFReader *pFReader = new STxtFileReader(pszFName); + char *pszLine = new char[ 100001 ]; + int iLen; + Map hashPredicate; + while (pFReader->fnReadNextLine(pszLine, &iLen)) { + if (iLen == 0) + continue; + + vector<string> vecTerms; + SplitOnWhitespace(string(pszLine), &vecTerms); + + for (size_t i = 0; i < vecTerms.size() - 1; i++) { + Iterator iter = hashPredicate.find(vecTerms[i]); + if (iter == hashPredicate.end()) { + hashPredicate[vecTerms[i]] = 1; + + } else { + iter->second++; + } + } + } + delete pFReader; + + pFReader = new STxtFileReader(pszFName); + FILE *fpOut = fopen(pszNewFName, "w"); + while (pFReader->fnReadNextLine(pszLine, &iLen)) { + if (iLen == 0) + continue; + + vector<string> vecTerms; + SplitOnWhitespace(string(pszLine), &vecTerms); + ostringstream ostr; + for (size_t i = 0; i < vecTerms.size() - 1; i++) { + Iterator iter = hashPredicate.find(vecTerms[i]); + assert(iter != hashPredicate.end()); + if (iter->second >= iCutoff) { + ostr << vecTerms[i] << " "; + } + } + if (ostr.str().length() > 0) { + ostr << vecTerms[vecTerms.size() - 1]; + fprintf(fpOut, "%s\n", ostr.str().c_str()); + } + } + fclose(fpOut); + delete pFReader; + + + delete [] pszLine; +} + +struct SArgumentReorderTrainer{ + SArgumentReorderTrainer(const char* pszSRLFname, //source-side srl tree file name + const char* pszAlignFname, //alignment filename + const char* pszSourceFname, //source file name + const char* pszTargetFname, //target file name + const char* pszTopPredicateFname, //target file name + const char* pszInstanceFname, //training instance file name + const char* pszModelFname, //classifier model file name + int iCutoff) { + fnGenerateInstanceFiles(pszSRLFname, pszAlignFname, pszSourceFname, pszTargetFname, pszTopPredicateFname, pszInstanceFname); + + string strInstanceFname, strModelFname; + strInstanceFname = string(pszInstanceFname) + string(".left"); + strModelFname = string(pszModelFname) + string(".left"); + fnTraining(strInstanceFname.c_str(), strModelFname.c_str(), iCutoff); + strInstanceFname = string(pszInstanceFname) + string(".right"); + strModelFname = string(pszModelFname) + string(".right"); + fnTraining(strInstanceFname.c_str(), strModelFname.c_str(), iCutoff); + } + + ~SArgumentReorderTrainer() { + + } + +private: + void fnTraining(const char* pszInstanceFname, const char* pszModelFname, int iCutoff) { + char *pszNewInstanceFName = new char[strlen(pszInstanceFname) + 50]; + if (iCutoff > 0) { + sprintf(pszNewInstanceFName, "%s.tmp", pszInstanceFname); + fnPreparingTrainingdata(pszInstanceFname, iCutoff, pszNewInstanceFName); + } else { + strcpy(pszNewInstanceFName, pszInstanceFname); + } + + Tsuruoka_Maxent *pMaxent = new Tsuruoka_Maxent(NULL); + pMaxent->fnTrain(pszNewInstanceFName, "l1", pszModelFname, 300); + delete pMaxent; + + if (strcmp(pszNewInstanceFName, pszInstanceFname) != 0) { + sprintf(pszNewInstanceFName, "rm %s.tmp", pszInstanceFname); + system(pszNewInstanceFName); + } + delete [] pszNewInstanceFName; + } + + void fnGenerateInstanceFiles(const char* pszSRLFname, //source-side flattened parse tree file name + const char* pszAlignFname, //alignment filename + const char* pszSourceFname, //source file name + const char* pszTargetFname, //target file name + const char* pszTopPredicateFname, //top predicate file name (we only consider predicates with 100+ occurrences + const char* pszInstanceFname //training instance file name + ) { + SAlignmentReader *pAlignReader = new SAlignmentReader(pszAlignFname); + SSrlSentenceReader *pSRLReader = new SSrlSentenceReader(pszSRLFname); + STxtFileReader *pTxtSReader = new STxtFileReader(pszSourceFname); + STxtFileReader *pTxtTReader = new STxtFileReader(pszTargetFname); + + Map *pMapPredicate; + if (pszTopPredicateFname != NULL) + pMapPredicate = fnLoadTopPredicates(pszTopPredicateFname); + else + pMapPredicate = NULL; + + char *pszLine = new char[50001]; + + FILE *fpLeftOut, *fpRightOut; + sprintf(pszLine, "%s.left", pszInstanceFname); + fpLeftOut = fopen(pszLine, "w"); + sprintf(pszLine, "%s.right", pszInstanceFname); + fpRightOut = fopen(pszLine, "w"); + + //read sentence by sentence + SAlignment *pAlign; + SSrlSentence *pSRL; + SParsedTree *pTree; + int iSentNum = 0; + while ((pAlign = pAlignReader->fnReadNextAlignment()) != NULL) { + pSRL = pSRLReader->fnReadNextSrlSentence(); + assert(pSRL != NULL); + pTree = pSRL->m_pTree; + assert(pTxtSReader->fnReadNextLine(pszLine, NULL)); + vector<string> vecSTerms; + SplitOnWhitespace(string(pszLine), &vecSTerms); + assert(pTxtTReader->fnReadNextLine(pszLine, NULL)); + vector<string> vecTTerms; + SplitOnWhitespace(string(pszLine), &vecTTerms); + //vecTPOSTerms.size() == 0, given the case when an english sentence fails parsing + + if (pTree != NULL) { + for (size_t i = 0; i < pSRL->m_vecPred.size(); i++) { + SPredicate *pPred = pSRL->m_vecPred[i]; + if (strcmp(pTree->m_vecTerminals[pPred->m_iPosition]->m_ptParent->m_pszTerm, "VA") == 0) continue; + string strPred = string(pTree->m_vecTerminals[pPred->m_iPosition]->m_pszTerm); + if (pMapPredicate != NULL) { + Map::iterator iter_map = pMapPredicate->find(strPred); + if (pMapPredicate != NULL && iter_map == pMapPredicate->end()) + continue; + } + + SPredicateItem *pPredItem = new SPredicateItem(pTree, pPred); + + vector<string> vecStrBlock; + for (size_t j = 0; j < pPredItem->vec_items_.size(); j++) { + SSRLItem *pItem1 = pPredItem->vec_items_[j]; + vecStrBlock.push_back(SArgumentReorderModel::fnGetBlockOutcome(pItem1->tree_item_->m_iBegin, pItem1->tree_item_->m_iEnd, pAlign)); + } + + vector<string> vecStrLeftReorderType; + vector<string> vecStrRightReorderType; + SArgumentReorderModel::fnGetReorderType(pPredItem, pAlign, vecStrLeftReorderType, vecStrRightReorderType); + for (int j = 1; j < pPredItem->vec_items_.size(); j++) { + string strLeftOutcome, strRightOutcome; + strLeftOutcome = vecStrLeftReorderType[j - 1]; + strRightOutcome = vecStrRightReorderType[j - 1]; + ostringstream ostr; + SArgumentReorderModel::fnGenerateFeature(pTree, pPred, pPredItem, j, vecStrBlock[j - 1], vecStrBlock[j], ostr); + + //fprintf(stderr, "%s %s\n", ostr.str().c_str(), strOutcome.c_str()); + //fprintf(fpOut, "sentid=%d %s %s\n", iSentNum, ostr.str().c_str(), strOutcome.c_str()); + fprintf(fpLeftOut, "%s %s\n", ostr.str().c_str(), strLeftOutcome.c_str()); + fprintf(fpRightOut, "%s %s\n", ostr.str().c_str(), strRightOutcome.c_str()); + } + } + } + delete pSRL; + + delete pAlign; + iSentNum++; + + if (iSentNum % 100000 == 0) + fprintf(stderr, "#%d\n", iSentNum); + } + delete [] pszLine; + + fclose(fpLeftOut); + fclose(fpRightOut); + + + delete pAlignReader; + delete pSRLReader; + delete pTxtSReader; + delete pTxtTReader; + } + + + + Map* fnLoadTopPredicates(const char* pszTopPredicateFname) { + if (pszTopPredicateFname == NULL) + return NULL; + + Map* pMapPredicate = new Map(); + STxtFileReader *pReader = new STxtFileReader(pszTopPredicateFname); + char *pszLine = new char[50001]; + int iNumCount = 0; + while (pReader->fnReadNextLine(pszLine, NULL)) { + if (pszLine[0] == '#') + continue; + char *p = strchr(pszLine, ' '); + assert(p != NULL); + p[0] = '\0'; + p++; + int iCount = atoi(p); + if (iCount < 100) + break; + (*pMapPredicate)[string(pszLine)] = iNumCount++; + } + delete pReader; + delete [] pszLine; + return pMapPredicate; + } +}; + + + +namespace po = boost::program_options; + +inline void print_options(std::ostream &out,po::options_description const& opts) { + typedef std::vector< boost::shared_ptr<po::option_description> > Ds; + Ds const& ds=opts.options(); + out << '"'; + for (unsigned i=0;i<ds.size();++i) { + if (i) out<<' '; + out<<"--"<<ds[i]->long_name(); + } + out << '"\n'; +} +inline string str(char const* name,po::variables_map const& conf) { + return conf[name].as<string>(); +} + +//--srl_file /scratch0/mt_exp/gale-align/gale-align.nw.srl.cn --align_file /scratch0/mt_exp/gale-align/gale-align.nw.al --source_file /scratch0/mt_exp/gale-align/gale-align.nw.cn --target_file /scratch0/mt_exp/gale-align/gale-align.nw.en --instance_file /scratch0/mt_exp/gale-align/gale-align.nw.argreorder.instance --model_prefix /scratch0/mt_exp/gale-align/gale-align.nw.argreorder.model --feature_cutoff 2 +//--srl_file /scratch0/mt_exp/gale-ctb/gale-ctb.srl.cn --align_file /scratch0/mt_exp/gale-ctb/gale-ctb.align --source_file /scratch0/mt_exp/gale-ctb/gale-ctb.cn --target_file /scratch0/mt_exp/gale-ctb/gale-ctb.en0 --instance_file /scratch0/mt_exp/gale-ctb/gale-ctb.argreorder.instance --model_prefix /scratch0/mt_exp/gale-ctb/gale-ctb.argreorder.model --feature_cutoff 2 +int main(int argc, char** argv) { + + po::options_description opts("Configuration options"); + opts.add_options() + ("srl_file",po::value<string>(),"srl file path (input)") + ("align_file",po::value<string>(),"Alignment file path (input)") + ("source_file",po::value<string>(),"Source text file path (input)") + ("target_file",po::value<string>(),"Target text file path (input)") + ("instance_file",po::value<string>(),"Instance file path (output)") + ("model_prefix",po::value<string>(),"Model file path prefix (output): three files will be generated") + ("feature_cutoff",po::value<int>()->default_value(100),"Feature cutoff threshold") + ("help", "produce help message"); + + po::variables_map vm; + if (argc) { + po::store(po::parse_command_line(argc, argv, opts), vm); + po::notify(vm); + } + + if (vm.count("help")) { + print_options(cout, opts); + return 1; + } + + if (!vm.count("srl_file") + || !vm.count("align_file") + || !vm.count("source_file") + || !vm.count("target_file") + || !vm.count("instance_file") + || !vm.count("model_prefix") + ) { + print_options(cout, opts); + if (!vm.count("parse_file")) cout << "--parse_file NOT FOUND\n"; + if (!vm.count("align_file")) cout << "--align_file NOT FOUND\n"; + if (!vm.count("source_file")) cout << "--source_file NOT FOUND\n"; + if (!vm.count("target_file")) cout << "--target_file NOT FOUND\n"; + if (!vm.count("instance_file")) cout << "--instance_file NOT FOUND\n"; + if (!vm.count("model_prefix")) cout << "--model_prefix NOT FOUND\n"; + exit(0); + } + + SArgumentReorderTrainer *pTrainer = new SArgumentReorderTrainer(str("srl_file", vm).c_str(), + str("align_file", vm).c_str(), + str("source_file", vm).c_str(), + str("target_file", vm).c_str(), + NULL, + str("instance_file", vm).c_str(), + str("model_prefix", vm).c_str(), + vm["feature_cutoff"].as<int>()); + delete pTrainer; + + return 1; +} |