diff options
author | Wu, Ke <wuke@cs.umd.edu> | 2014-12-17 16:15:13 -0500 |
---|---|---|
committer | Wu, Ke <wuke@cs.umd.edu> | 2014-12-17 16:15:13 -0500 |
commit | 6829a0bc624b02ebefc79f8cf9ec89d7d64a7c30 (patch) | |
tree | 125dfb20f73342873476c793995397b26fd202dd /training/const_reorder/trainer.cc | |
parent | b455a108a21f4ba5a58ab1bc53a8d2bf4d829067 (diff) | |
parent | 7468e8d85e99b4619442c7afaf4a0d92870111bb (diff) |
Merge branch 'const_reorder_2' into softsyn_2
Diffstat (limited to 'training/const_reorder/trainer.cc')
-rw-r--r-- | training/const_reorder/trainer.cc | 67 |
1 files changed, 67 insertions, 0 deletions
diff --git a/training/const_reorder/trainer.cc b/training/const_reorder/trainer.cc new file mode 100644 index 00000000..89bd7479 --- /dev/null +++ b/training/const_reorder/trainer.cc @@ -0,0 +1,67 @@ +#include "trainer.h" + +Tsuruoka_Maxent_Trainer::Tsuruoka_Maxent_Trainer() + : const_reorder::Tsuruoka_Maxent(NULL) {} + +void Tsuruoka_Maxent_Trainer::fnTrain(const char* pszInstanceFName, + const char* pszAlgorithm, + const char* pszModelFName) { + assert(strcmp(pszAlgorithm, "l1") == 0 || strcmp(pszAlgorithm, "l2") == 0 || + strcmp(pszAlgorithm, "sgd") == 0 || strcmp(pszAlgorithm, "SGD") == 0); + FILE* fpIn = fopen(pszInstanceFName, "r"); + + maxent::ME_Model* pModel = new maxent::ME_Model(); + + char* pszLine = new char[100001]; + int iNumInstances = 0; + int iLen; + while (!feof(fpIn)) { + pszLine[0] = '\0'; + fgets(pszLine, 20000, fpIn); + if (strlen(pszLine) == 0) { + continue; + } + + iLen = strlen(pszLine); + while (iLen > 0 && pszLine[iLen - 1] > 0 && pszLine[iLen - 1] < 33) { + pszLine[iLen - 1] = '\0'; + iLen--; + } + + iNumInstances++; + + maxent::ME_Sample* pmes = new maxent::ME_Sample(); + + char* p = strrchr(pszLine, ' '); + assert(p != NULL); + p[0] = '\0'; + p++; + std::vector<std::string> vecContext; + SplitOnWhitespace(std::string(pszLine), &vecContext); + + pmes->label = std::string(p); + for (size_t i = 0; i < vecContext.size(); i++) + pmes->add_feature(vecContext[i]); + pModel->add_training_sample((*pmes)); + if (iNumInstances % 100000 == 0) + fprintf(stdout, "......Reading #Instances: %1d\n", iNumInstances); + delete pmes; + } + fprintf(stdout, "......Reading #Instances: %1d\n", iNumInstances); + fclose(fpIn); + + if (strcmp(pszAlgorithm, "l1") == 0) + pModel->use_l1_regularizer(1.0); + else if (strcmp(pszAlgorithm, "l2") == 0) + pModel->use_l2_regularizer(1.0); + else + pModel->use_SGD(); + + pModel->train(); + pModel->save_to_file(pszModelFName); + + delete pModel; + fprintf(stdout, "......Finished Training\n"); + fprintf(stdout, "......Model saved as %s\n", pszModelFName); + delete[] pszLine; +} |