diff options
author | Chris Dyer <redpony@gmail.com> | 2014-12-17 22:04:58 -0500 |
---|---|---|
committer | Chris Dyer <redpony@gmail.com> | 2014-12-17 22:04:58 -0500 |
commit | a059245e63178682624e10cde69f171820bd9209 (patch) | |
tree | a6f17da7c69048c8900260b5490bb9d8611be3bb /training/const_reorder/trainer.cc | |
parent | 1a79175f9a101d46cf27ca921213d5dd9300518f (diff) | |
parent | 7468e8d85e99b4619442c7afaf4a0d92870111bb (diff) |
Merge pull request #63 from kho/const_reorder_2
Soft linguistic reordering constraints from http://www.aclweb.org/anthology/P14-1106
Diffstat (limited to 'training/const_reorder/trainer.cc')
-rw-r--r-- | training/const_reorder/trainer.cc | 67 |
1 files changed, 67 insertions, 0 deletions
diff --git a/training/const_reorder/trainer.cc b/training/const_reorder/trainer.cc new file mode 100644 index 00000000..89bd7479 --- /dev/null +++ b/training/const_reorder/trainer.cc @@ -0,0 +1,67 @@ +#include "trainer.h" + +Tsuruoka_Maxent_Trainer::Tsuruoka_Maxent_Trainer() + : const_reorder::Tsuruoka_Maxent(NULL) {} + +void Tsuruoka_Maxent_Trainer::fnTrain(const char* pszInstanceFName, + const char* pszAlgorithm, + const char* pszModelFName) { + assert(strcmp(pszAlgorithm, "l1") == 0 || strcmp(pszAlgorithm, "l2") == 0 || + strcmp(pszAlgorithm, "sgd") == 0 || strcmp(pszAlgorithm, "SGD") == 0); + FILE* fpIn = fopen(pszInstanceFName, "r"); + + maxent::ME_Model* pModel = new maxent::ME_Model(); + + char* pszLine = new char[100001]; + int iNumInstances = 0; + int iLen; + while (!feof(fpIn)) { + pszLine[0] = '\0'; + fgets(pszLine, 20000, fpIn); + if (strlen(pszLine) == 0) { + continue; + } + + iLen = strlen(pszLine); + while (iLen > 0 && pszLine[iLen - 1] > 0 && pszLine[iLen - 1] < 33) { + pszLine[iLen - 1] = '\0'; + iLen--; + } + + iNumInstances++; + + maxent::ME_Sample* pmes = new maxent::ME_Sample(); + + char* p = strrchr(pszLine, ' '); + assert(p != NULL); + p[0] = '\0'; + p++; + std::vector<std::string> vecContext; + SplitOnWhitespace(std::string(pszLine), &vecContext); + + pmes->label = std::string(p); + for (size_t i = 0; i < vecContext.size(); i++) + pmes->add_feature(vecContext[i]); + pModel->add_training_sample((*pmes)); + if (iNumInstances % 100000 == 0) + fprintf(stdout, "......Reading #Instances: %1d\n", iNumInstances); + delete pmes; + } + fprintf(stdout, "......Reading #Instances: %1d\n", iNumInstances); + fclose(fpIn); + + if (strcmp(pszAlgorithm, "l1") == 0) + pModel->use_l1_regularizer(1.0); + else if (strcmp(pszAlgorithm, "l2") == 0) + pModel->use_l2_regularizer(1.0); + else + pModel->use_SGD(); + + pModel->train(); + pModel->save_to_file(pszModelFName); + + delete pModel; + fprintf(stdout, "......Finished Training\n"); + fprintf(stdout, "......Model saved as %s\n", pszModelFName); + delete[] pszLine; +} |