Refactor the code to make the compiler optimize it easily.
weixingzhang committed Nov 30, 2018
1 parent a01f92e commit 5f87427
Showing 1 changed file with 62 additions and 29 deletions.
91 changes: 62 additions & 29 deletions onnxruntime/core/graph/initializer.h
@@ -190,19 +190,22 @@ class Initializer final {
     return dims_;
   }
 
-  size_t size() const { return size_; }
+  int64_t size() const { return size_; }
 
   Initializer& add(float value) {
+    int64_t n = size();
     switch (data_type_) {
       case ONNX_NAMESPACE::TensorProto_DataType_FLOAT: {
-        for (int i = 0; i < size_; i++) {
-          data<float>()[i] += value;
+        float* dst = data<float>();
+        for (int i = 0; i < n; i++) {
+          dst[i] += value;
         }
         break;
       }
       case ONNX_NAMESPACE::TensorProto_DataType_DOUBLE: {
-        for (int i = 0; i < size_; i++) {
-          data<double>()[i] += value;
+        double* dst = data<double>();
+        for (int i = 0; i < n; i++) {
+          dst[i] += value;
         }
         break;
       }
@@ -213,16 +216,21 @@ class Initializer final {
   }
 
   Initializer& add(const Initializer& other) {
+    int64_t n = size();
     switch (data_type_) {
       case ONNX_NAMESPACE::TensorProto_DataType_FLOAT: {
-        for (int i = 0; i < size_; i++) {
-          data<float>()[i] += other.data<float>()[i];
+        float* dst = data<float>();
+        const float* src = other.data<float>();
+        for (int i = 0; i < n; i++) {
+          dst[i] += src[i];
         }
         break;
       }
       case ONNX_NAMESPACE::TensorProto_DataType_DOUBLE: {
-        for (int i = 0; i < size_; i++) {
-          data<double>()[i] += other.data<double>()[i];
+        double* dst = data<double>();
+        const double* src = other.data<double>();
+        for (int i = 0; i < n; i++) {
+          dst[i] += src[i];
         }
         break;
       }
@@ -232,16 +240,21 @@ class Initializer final {
     return *this;
   }
   Initializer& sub(const Initializer& other) {
+    int64_t n = size();
     switch (data_type_) {
       case ONNX_NAMESPACE::TensorProto_DataType_FLOAT: {
-        for (int i = 0; i < size_; i++) {
-          data<float>()[i] -= other.data<float>()[i];
+        float* dst = data<float>();
+        const float* src = other.data<float>();
+        for (int i = 0; i < n; i++) {
+          dst[i] -= src[i];
        }
         break;
       }
       case ONNX_NAMESPACE::TensorProto_DataType_DOUBLE: {
-        for (int i = 0; i < size_; i++) {
-          data<double>()[i] -= other.data<double>()[i];
+        double* dst = data<double>();
+        const double* src = other.data<double>();
+        for (int i = 0; i < n; i++) {
+          dst[i] -= src[i];
         }
         break;
       }
@@ -252,16 +265,21 @@ class Initializer final {
   }
 
   Initializer& mul(const Initializer& other) {
+    int64_t n = size();
     switch (data_type_) {
       case ONNX_NAMESPACE::TensorProto_DataType_FLOAT: {
-        for (int i = 0; i < size_; i++) {
-          data<float>()[i] *= other.data<float>()[i];
+        float* dst = data<float>();
+        const float* src = other.data<float>();
+        for (int i = 0; i < n; i++) {
+          dst[i] *= src[i];
         }
         break;
       }
       case ONNX_NAMESPACE::TensorProto_DataType_DOUBLE: {
-        for (int i = 0; i < size_; i++) {
-          data<double>()[i] *= other.data<double>()[i];
+        double* dst = data<double>();
+        const double* src = other.data<double>();
+        for (int i = 0; i < n; i++) {
+          dst[i] *= src[i];
         }
         break;
       }
@@ -271,16 +289,21 @@ class Initializer final {
     return *this;
   }
   Initializer& div(const Initializer& other) {
+    int64_t n = size();
     switch (data_type_) {
       case ONNX_NAMESPACE::TensorProto_DataType_FLOAT: {
-        for (int i = 0; i < size_; i++) {
-          data<float>()[i] /= other.data<float>()[i];
+        float* dst = data<float>();
+        const float* src = other.data<float>();
+        for (int i = 0; i < n; i++) {
+          dst[i] /= src[i];
         }
         break;
       }
       case ONNX_NAMESPACE::TensorProto_DataType_DOUBLE: {
-        for (int i = 0; i < size_; i++) {
-          data<double>()[i] /= other.data<double>()[i];
+        double* dst = data<double>();
+        const double* src = other.data<double>();
+        for (int i = 0; i < n; i++) {
+          dst[i] /= src[i];
         }
         break;
       }
@@ -291,16 +314,19 @@ class Initializer final {
   }
 
   Initializer& sqrt() {
+    int64_t n = size();
     switch (data_type_) {
      case ONNX_NAMESPACE::TensorProto_DataType_FLOAT: {
-        for (int i = 0; i < size_; i++) {
-          data<float>()[i] = std::sqrt(data<float>()[i]);
+        float* dst = data<float>();
+        for (int i = 0; i < n; i++) {
+          dst[i] = std::sqrt(dst[i]);
         }
         break;
       }
       case ONNX_NAMESPACE::TensorProto_DataType_DOUBLE: {
-        for (int i = 0; i < size_; i++) {
-          data<double>()[i] = std::sqrt(data<double>()[i]);
+        double* dst = data<double>();
+        for (int i = 0; i < n; i++) {
+          dst[i] = std::sqrt(dst[i]);
         }
         break;
       }
@@ -316,19 +342,26 @@ class Initializer final {
       num *= dims_[k];
     }
 
+    int64_t n = size()/num;
     switch (data_type_) {
       case ONNX_NAMESPACE::TensorProto_DataType_FLOAT: {
-        for (size_t i = 0; i < size() / num; i++) {
+        float* dst = data<float>();
+        const float* src = other.data<float>();
+        for (int i = 0; i < n; i++) {
+          int index = other.size() == 1 ? 0 : i;
           for (int64_t j = 0; j < num; j++) {
-            data<float>()[i * num + j] *= other.data<float>()[other.size() == 1 ? 0 : i];
+            dst[i * num + j] *= src[index];
           }
         }
         break;
       }
       case ONNX_NAMESPACE::TensorProto_DataType_DOUBLE: {
-        for (size_t i = 0; i < size() / num; i++) {
+        double* dst = data<double>();
+        const double* src = other.data<double>();
+        for (int i = 0; i < n; i++) {
+          int index = other.size() == 1 ? 0 : i;
           for (int64_t j = 0; j < num; j++) {
-            data<double>()[i * num + j] *= other.data<double>()[other.size() == 1 ? 0 : i];
+            dst[i * num + j] *= src[index];
           }
         }
         break;
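Every hunk above applies the same transformation: the element count and the raw data pointer(s) are read once into locals before the loop, and the loop body then touches only those locals. The sketch below is not part of the commit or the repository; Buffer, raw_data_, and the method names are hypothetical stand-ins for Initializer's internals, used only to illustrate why this shape is easier for a compiler to auto-vectorize.

// Minimal illustrative sketch, assuming a data()/size() accessor pair
// similar to Initializer's (names are hypothetical).
#include <cstdint>
#include <vector>

struct Buffer {
  std::vector<float> raw_data_;

  float* data() { return raw_data_.data(); }
  int64_t size() const { return static_cast<int64_t>(raw_data_.size()); }

  // Before: every iteration re-evaluates data() and size(). Unless the
  // optimizer can prove both calls are loop-invariant and that the stores
  // through the returned pointer cannot change what they return, it keeps
  // re-deriving them, which tends to block auto-vectorization.
  void add_before(float value) {
    for (int i = 0; i < size(); i++) {
      data()[i] += value;
    }
  }

  // After (the shape used in the commit): the trip count and destination
  // pointer are plain locals, so the loop is a straightforward SIMD candidate.
  void add_after(float value) {
    int64_t n = size();
    float* dst = data();
    for (int64_t i = 0; i < n; i++) {
      dst[i] += value;
    }
  }
};

The same reasoning covers the two-operand variants: caching other.data<T>() in a local src pointer keeps the per-iteration work down to two indexed loads, one arithmetic op, and one store.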
