diff --git a/src/cxxnet_main.cpp b/src/cxxnet_main.cpp index 95df6e63..f6d2dda8 100644 --- a/src/cxxnet_main.cpp +++ b/src/cxxnet_main.cpp @@ -13,6 +13,7 @@ #include "nnet/nnet.h" #include "io/data.h" #include "utils/config.h" +#include #if MSHADOW_DIST_PS #include "ps.h" @@ -178,23 +179,39 @@ class CXXNetLearnTask { } // load in latest model from model_folder inline int SyncLastestModel(void) { - dmlc::Stream *fi = NULL, *last = NULL; - int s_counter = start_counter; - do{ - if (last != NULL) delete last; - last = fi; - std::ostringstream os; - os << name_model_dir << '/' << std::setfill('0') - << std::setw(4) << s_counter++ << ".model"; - fi = dmlc::Stream::Create(os.str().c_str(), "r", true); - } while (fi != NULL); + dmlc::Stream *fi = NULL, *next = NULL; + if (name_model_in != "NULL") + { + if (name_model_in.find('/') == -1 ) + name_model_in = name_model_dir + "/" + name_model_in; + fi = dmlc::Stream::Create( name_model_in.c_str(), "r", true); + const char* counter = strrchr(name_model_in.c_str(), '/'); + sscanf(counter+1, "%d", &start_counter); + } + if (fi == NULL) + { + char model_path[50]; + sprintf(model_path, "%s/%04d.model", name_model_dir.c_str(), start_counter); + fi = dmlc::Stream::Create(model_path, "r", true); + if (fi == NULL) + { + start_counter = 1 - save_period; + do{ + start_counter += save_period; + if (fi != NULL) delete fi; + fi = next; + sprintf(model_path, "%s/%04d.model", name_model_dir.c_str(), start_counter); + next = dmlc::Stream::Create(model_path, "r", true); + } while (next != NULL); + start_counter -= save_period; + } + } - if (last != NULL) { - CHECK(last->Read(&net_type, sizeof(int)) != 0) << "invalid model format"; + if (fi != NULL) { + CHECK(fi->Read(&net_type, sizeof(int)) != 0) << "invalid model format"; net_trainer = this->CreateNet(); - net_trainer->LoadModel(*last); - start_counter = s_counter - 1; - delete last; + net_trainer->LoadModel(*fi); + delete fi; return 1; } else { return 0;