diff --git a/paddle/operators/adam_op.cc b/paddle/operators/adam_op.cc index e3db70ea12988..3572de06bd60f 100644 --- a/paddle/operators/adam_op.cc +++ b/paddle/operators/adam_op.cc @@ -43,10 +43,6 @@ class AdamOp : public framework::OperatorWithKernel { "Output(Moment1Out) of AdamOp should not be null."); PADDLE_ENFORCE(ctx->HasOutput("Moment2Out"), "Output(Moment2Out) of AdamOp should not be null."); - PADDLE_ENFORCE(ctx->HasOutput("Beta1PowOut"), - "Output(Beta1PowOut) of AdamOp should not be null."); - PADDLE_ENFORCE(ctx->HasOutput("Beta2PowOut"), - "Output(Beta2PowOut) of AdamOp should not be null."); auto lr_dims = ctx->GetInputDim("LearningRate"); PADDLE_ENFORCE_EQ(framework::product(lr_dims), 1, @@ -72,8 +68,6 @@ class AdamOp : public framework::OperatorWithKernel { ctx->SetOutputDim("ParamOut", param_dims); ctx->SetOutputDim("Moment1Out", param_dims); ctx->SetOutputDim("Moment2Out", param_dims); - ctx->SetOutputDim("Beta1PowOut", beta1_pow_dims); - ctx->SetOutputDim("Beta2PowOut", beta2_pow_dims); } }; @@ -92,8 +86,6 @@ class AdamOpMaker : public framework::OpProtoAndCheckerMaker { AddOutput("ParamOut", "(Tensor) Output parameter"); AddOutput("Moment1Out", "(Tensor) Output first moment"); AddOutput("Moment2Out", "(Tensor) Output second moment"); - AddOutput("Beta1PowOut", "(Tensor) Output beta1 power accumulator"); - AddOutput("Beta2PowOut", "(Tensor) Output beta2 power accumulator"); AddAttr("beta1", "(float, default 0.9) " @@ -121,10 +113,8 @@ Adam updates: moment1_out = beta1 * moment1 + (1 − beta1) * grad moment2_out = beta2 * moment2 + (1 − beta2) * grad * grad -beta1_pow_out = beta1_pow * beta1 -beta2_pow_out = beta2_pow * beta2 learning_rate_t = learning_rate_t * - sqrt(1 - beta2_pow_out) / (1 - beta1_pow_out) + sqrt(1 - beta2_pow) / (1 - beta1_pow) param_out = param - learning_rate_t * moment1/ (sqrt(moment2) + epsilon) References: diff --git a/paddle/operators/adam_op.h b/paddle/operators/adam_op.h index 789c2f14b3247..45938006db123 100644 --- a/paddle/operators/adam_op.h +++ b/paddle/operators/adam_op.h @@ -26,14 +26,10 @@ class AdamOpKernel : public framework::OpKernel { auto param_out_tensor = ctx.Output("ParamOut"); auto moment1_out_tensor = ctx.Output("Moment1Out"); auto moment2_out_tensor = ctx.Output("Moment2Out"); - auto beta1_pow_out_tensor = ctx.Output("Beta1PowOut"); - auto beta2_pow_out_tensor = ctx.Output("Beta2PowOut"); param_out_tensor->mutable_data(ctx.GetPlace()); moment1_out_tensor->mutable_data(ctx.GetPlace()); moment2_out_tensor->mutable_data(ctx.GetPlace()); - beta1_pow_out_tensor->mutable_data(ctx.GetPlace()); - beta2_pow_out_tensor->mutable_data(ctx.GetPlace()); float beta1 = ctx.Attr("beta1"); float beta2 = ctx.Attr("beta2"); @@ -56,18 +52,13 @@ class AdamOpKernel : public framework::OpKernel { auto param_out = framework::EigenVector::Flatten(*param_out_tensor); auto moment1_out = framework::EigenVector::Flatten(*moment1_out_tensor); auto moment2_out = framework::EigenVector::Flatten(*moment2_out_tensor); - auto beta1_pow_out = - framework::EigenVector::Flatten(*beta1_pow_out_tensor); - auto beta2_pow_out = - framework::EigenVector::Flatten(*beta2_pow_out_tensor); auto place = ctx.GetEigenDevice(); moment1_out.device(place) = beta1 * moment1 + (1 - beta1) * grad; moment2_out.device(place) = beta2 * moment2 + (1 - beta2) * grad.square(); - beta1_pow_out.device(place) = beta1_pow * beta1; - beta2_pow_out.device(place) = beta2_pow * beta2; + // All of these are tensors of 1 element - auto lr_t = lr * (1 - beta2_pow_out).sqrt() / (1 - beta1_pow_out); + auto lr_t = lr * (1 - beta2_pow).sqrt() / (1 - beta1_pow); // Eigen does not support automatic broadcast // Get dimensions of moment vector to broadcast lr_t Eigen::DSizes m_dsize(moment1_out_tensor->numel()); diff --git a/python/paddle/v2/framework/tests/test_adam_op.py b/python/paddle/v2/framework/tests/test_adam_op.py index ff6faafa6e211..a0d6655d4cbcf 100644 --- a/python/paddle/v2/framework/tests/test_adam_op.py +++ b/python/paddle/v2/framework/tests/test_adam_op.py @@ -33,14 +33,12 @@ def setUp(self): self.attrs = {'epsilon': epsilon, 'beta1': beta1, 'beta2': beta2} - param_out, moment1_out, moment2_out, beta1_pow_out, \ - beta2_pow_out = adam_step(self.inputs, self.attrs) + param_out, moment1_out, \ + moment2_out = adam_step(self.inputs, self.attrs) self.outputs = { 'Moment1Out': moment1_out, 'Moment2Out': moment2_out, - 'Beta1PowOut': beta1_pow_out, - 'Beta2PowOut': beta2_pow_out, 'ParamOut': param_out } @@ -78,14 +76,12 @@ def setUp(self): attributes = {'epsilon': epsilon, 'beta1': beta1, 'beta2': beta2} - param_out, moment1_out, moment2_out, beta1_pow_out, \ - beta2_pow_out = adam_step(self.inputs, attributes) + param_out, moment1_out, \ + moment2_out = adam_step(self.inputs, attributes) self.outputs = { 'Moment1Out': moment1_out, 'Moment2Out': moment2_out, - 'Beta1PowOut': beta1_pow_out, - 'Beta2PowOut': beta2_pow_out, 'ParamOut': param_out } @@ -127,14 +123,12 @@ def setUp(self): def test_check_output(self): for _ in range(self.num_steps): - param_out, moment1_out, moment2_out, beta1_pow_out, \ - beta2_pow_out = adam_step(self.inputs, self.attrs) + param_out, moment1_out, \ + moment2_out = adam_step(self.inputs, self.attrs) self.outputs = { 'Moment1Out': moment1_out, 'Moment2Out': moment2_out, - 'Beta1PowOut': beta1_pow_out, - 'Beta2PowOut': beta2_pow_out, 'ParamOut': param_out } @@ -145,8 +139,10 @@ def test_check_output(self): self.inputs['Param'] = param_out self.inputs['Moment1'] = moment1_out self.inputs['Moment2'] = moment2_out - self.inputs['Beta1Pow'] = beta1_pow_out - self.inputs['Beta2Pow'] = beta2_pow_out + + # Update powers of Beta1 and Beta2 for next time step + self.inputs['Beta1Pow'] *= self.attrs['beta1'] + self.inputs['Beta2Pow'] *= self.attrs['beta1'] # Randomize gradient for next step self.inputs['Grad'] = np.random.uniform( @@ -175,11 +171,9 @@ def adam_step(inputs, attributes): moment1_out = beta1 * moment1 + (1 - beta1) * grad moment2_out = beta2 * moment2 + (1 - beta2) * np.square(grad) - beta1_pow_out = beta1_pow * beta1 - beta2_pow_out = beta2_pow * beta2 - lr_t = lr * np.sqrt(1 - beta2_pow_out) / (1 - beta1_pow_out) + lr_t = lr * np.sqrt(1 - beta2_pow) / (1 - beta1_pow) param_out = param - lr_t * (moment1_out / (np.sqrt(moment2_out) + epsilon)) - return param_out, moment1_out, moment2_out, beta1_pow_out, beta2_pow_out + return param_out, moment1_out, moment2_out if __name__ == "__main__":