Skip to content
Permalink

Comparing changes

This is a direct comparison between two commits made in this repository or its related repositories. View the default comparison for this range or learn more about diff comparisons.

Open a pull request

Create a new pull request by comparing changes across two branches. If you need to, you can also . Learn more about diff comparisons here.
base repository: BVLC/caffe
Failed to load repositories. Confirm that selected base ref is valid, then try again.
Loading
base: 9478a485dcdf142eca32d2a7d8231856b42c6d6c
Choose a base ref
..
head repository: BVLC/caffe
Failed to load repositories. Confirm that selected head ref is valid, then try again.
Loading
compare: dcd0e838cae6b037a23920154024c96338c5587e
Choose a head ref
Showing with 6 additions and 9 deletions.
  1. +3 −6 include/caffe/solver.hpp
  2. +3 −3 src/caffe/solver.cpp
9 changes: 3 additions & 6 deletions include/caffe/solver.hpp
Original file line number Diff line number Diff line change
@@ -35,9 +35,6 @@ class Solver {
int iter() { return iter_; }

protected:
// PreSolve is run before any solving iteration starts, allowing one to
// put up some scaffold.
virtual void PreSolve() {}
// Get the update value for the current iteration.
virtual void ComputeUpdateValue() = 0;
// The Solver::Snapshot function implements the basic snapshotting utility
@@ -74,14 +71,14 @@ template <typename Dtype>
class SGDSolver : public Solver<Dtype> {
public:
explicit SGDSolver(const SolverParameter& param)
: Solver<Dtype>(param) {}
: Solver<Dtype>(param) { PreSolve(); }
explicit SGDSolver(const string& param_file)
: Solver<Dtype>(param_file) {}
: Solver<Dtype>(param_file) { PreSolve(); }

const vector<shared_ptr<Blob<Dtype> > >& history() { return history_; }

protected:
virtual void PreSolve();
void PreSolve();
Dtype GetLearningRate();
virtual void ComputeUpdateValue();
virtual void SnapshotSolverState(SolverState * state);
6 changes: 3 additions & 3 deletions src/caffe/solver.cpp
Original file line number Diff line number Diff line change
@@ -42,7 +42,6 @@ void Solver<Dtype>::Init(const SolverParameter& param) {
LOG(INFO) << "Solver scaffolding done.";
iter_ = 0;
current_step_ = 0;
PreSolve();
}

template <typename Dtype>
@@ -160,7 +159,6 @@ void Solver<Dtype>::InitTestNets() {

template <typename Dtype>
void Solver<Dtype>::Step(int iters) {
vector<Blob<Dtype>*> bottom_vec;
const int start_iter = iter_;
const int stop_iter = iter_ + iters;
int average_loss = this->param_.average_loss();
@@ -175,7 +173,9 @@ void Solver<Dtype>::Step(int iters) {

const bool display = param_.display() && iter_ % param_.display() == 0;
net_->set_debug_info(display && param_.debug_info());
Dtype loss = net_->ForwardBackward(bottom_vec);
Dtype loss;
net_->ForwardPrefilled(&loss);
net_->Backward();
if (losses.size() < average_loss) {
losses.push_back(loss);
int size = losses.size();