Skip to content

Commit

Permalink
Merge pull request #963 from reyoung/feature/add_const_in_parameter_u…
Browse files Browse the repository at this point in the history
…pdater

Add const in ParameterUpdater init
  • Loading branch information
reyoung authored Dec 20, 2016
2 parents 2965df5 + 0d1703d commit dadd48a
Show file tree
Hide file tree
Showing 8 changed files with 16 additions and 14 deletions.
2 changes: 1 addition & 1 deletion paddle/parameter/ParameterUpdaterBase.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ limitations under the License. */

namespace paddle {

void ParameterUpdater::init(std::vector<ParameterPtr>& parameters) {
void ParameterUpdater::init(const std::vector<ParameterPtr>& parameters) {
parameters_ = parameters;
for (ParameterType type : getParameterTypes()) {
for (auto& para : parameters) {
Expand Down
4 changes: 2 additions & 2 deletions paddle/parameter/ParameterUpdaterBase.h
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ class ParameterUpdater {
parameterTypes_.push_back(type);
}

virtual void init(std::vector<ParameterPtr>& parameters);
virtual void init(const std::vector<ParameterPtr>& parameters);

// called by Trainer when starting a new pass
virtual void startPass() {}
Expand Down Expand Up @@ -105,7 +105,7 @@ class ParameterUpdaterComposite : public ParameterUpdater {
ParameterUpdaterComposite() {}
virtual ~ParameterUpdaterComposite() {}

virtual void init(std::vector<ParameterPtr>& parameters) = 0;
virtual void init(const std::vector<ParameterPtr>& parameters) = 0;

virtual void startPass() {
syncThreadPool_->execPlusOwner(
Expand Down
3 changes: 2 additions & 1 deletion paddle/trainer/ParameterUpdater.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,8 @@ SgdUpdaterWithCpuAverager::SgdUpdaterWithCpuAverager(
updateWorker_.addJob([]() { hl_set_device(FLAGS_gpu_id); });
}

void SgdUpdaterWithCpuAverager::init(std::vector<ParameterPtr>& parameters) {
void SgdUpdaterWithCpuAverager::init(
const std::vector<ParameterPtr>& parameters) {
SgdLocalUpdater::init(parameters);
averager_->init(parameters_.size(), nullptr);
copyEvents_.resize(parameters_.size());
Expand Down
4 changes: 2 additions & 2 deletions paddle/trainer/ParameterUpdater.h
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ class SgdLocalUpdater : public ParameterUpdater {
* be initialized.
* @param parameters The parameter need to be initialized.
*/
virtual void init(std::vector<ParameterPtr>& parameters) {
virtual void init(const std::vector<ParameterPtr>& parameters) {
ParameterUpdater::init(parameters);
optimizer_->init(parameters_.size(), nullptr);
// check no L1 decay in parameter configs
Expand Down Expand Up @@ -208,7 +208,7 @@ class SgdUpdaterWithCpuAverager : public SgdLocalUpdater {
* @brief init. Initialize cpu parameters, model average optimizer.
* @param parameters
*/
virtual void init(std::vector<ParameterPtr>& parameters);
virtual void init(const std::vector<ParameterPtr>& parameters);

virtual PassType startBatch(int64_t batchSize) {
averager_->startBatch(-1UL);
Expand Down
7 changes: 4 additions & 3 deletions paddle/trainer/RemoteParameterUpdater.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ RemoteParameterUpdater::RemoteParameterUpdater(
addParameterType(PARAMETER_MOMENTUM);
}

void RemoteParameterUpdater::init(std::vector<ParameterPtr>& parameters) {
void RemoteParameterUpdater::init(const std::vector<ParameterPtr>& parameters) {
ParameterUpdater::init(parameters);

if (localUpdater_) {
Expand Down Expand Up @@ -595,7 +595,8 @@ SparseRemoteParameterUpdater::SparseRemoteParameterUpdater(
testing_(testing),
useApplyInPserver_(false) {}

void SparseRemoteParameterUpdater::init(std::vector<ParameterPtr>& parameters) {
void SparseRemoteParameterUpdater::init(
const std::vector<ParameterPtr>& parameters) {
ParameterUpdater::init(parameters);

parameterClient_.reset(new ParameterClient2(
Expand Down Expand Up @@ -809,7 +810,7 @@ void SparseRemoteParameterUpdater::saveParametersRemote(
}

void SparseRemoteParameterUpdaterComposite::init(
std::vector<ParameterPtr>& parameters) {
const std::vector<ParameterPtr>& parameters) {
parameters_ = parameters;

std::vector<ParameterPtr> parametersArray[NUMBER_UPDATERS];
Expand Down
6 changes: 3 additions & 3 deletions paddle/trainer/RemoteParameterUpdater.h
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ class RemoteParameterUpdater : public ParameterUpdater {
/**
* initialize the internal parameter client and itself.
*/
virtual void init(std::vector<ParameterPtr>& parameters);
virtual void init(const std::vector<ParameterPtr>& parameters);
/**
* @brief start batch
*
Expand Down Expand Up @@ -274,7 +274,7 @@ class SparseRemoteParameterUpdater : public ParameterUpdater {
}

/// initialization
virtual void init(std::vector<ParameterPtr>& parameters);
virtual void init(const std::vector<ParameterPtr>& parameters);

/// stateful batch control
virtual PassType startBatch(int64_t batchSize);
Expand Down Expand Up @@ -360,7 +360,7 @@ class SparseRemoteParameterUpdaterComposite : public ParameterUpdaterComposite {
}

/// initialization of dense and sparse updaters
virtual void init(std::vector<ParameterPtr>& parameters);
virtual void init(const std::vector<ParameterPtr>& parameters);
};

class ParameterUpdaterCreators {
Expand Down
2 changes: 1 addition & 1 deletion paddle/trainer/ThreadParameterUpdater.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ SgdThreadUpdater::SgdThreadUpdater(const OptimizationConfig& optConfig)
}
}

void SgdThreadUpdater::init(std::vector<ParameterPtr>& parameters) {
void SgdThreadUpdater::init(const std::vector<ParameterPtr>& parameters) {
ParameterUpdater::init(parameters);

// calc max parameter id
Expand Down
2 changes: 1 addition & 1 deletion paddle/trainer/ThreadParameterUpdater.h
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ class SgdThreadUpdater : public ParameterUpdater {
// Use the finishPass() function of the base optimizer.
virtual bool finishPass(real cost);

virtual void init(std::vector<ParameterPtr>& parameters);
virtual void init(const std::vector<ParameterPtr>& parameters);
virtual PassType startBatch(int64_t batchSize);
// Call finishBatch for each optimizer.
virtual void finishBatch(real cost);
Expand Down

0 comments on commit dadd48a

Please sign in to comment.