Lower Lerp #2972

Merged: 4 commits merged on Jun 4, 2021
2 changes: 1 addition & 1 deletion OP_LOWERING_GUIDE.md
@@ -24,4 +24,4 @@ All files mentioned below live under the `xla/torch_xla/csrc` folder, with the e
Our CircleCI runs the PyTorch native Python tests for every change and on a daily schedule. Those tests will use the XLA implementation if we provide a lowering. We usually don't need to add additional Python tests for PyTorch/XLA unless we want to verify some XLA behaviors (like dynamic shape) or we skipped the PyTorch native test for some reason. If a Python test is required, it should be added to `xla/test/test_operations.py`. We also need to add C++ tests in `xla/test/cpp/test_aten_xla_tensor.cpp`. These tests should call the PyTorch C++ API and verify that our implementation yields the same result as the PyTorch native implementation. We also need to verify that the XLA implementation is called when the tensor is an XLA tensor, by checking the `aten::op` and `xla::op` counters.

## Tips
Lowering is the process of breaking a PyTorch operation down into a sequence of XlaOps. To provide a good lowering of a PyTorch operation, one needs a good grasp of what XLA is capable of. Reading the XlaOp documentation and looking into how similar ops are lowered is the best way to achieve that. You can find a minimal op lowering example in [this pr](https://github.com/pytorch/xla/pull/2969). You can also find a slightly more complicated example with backward lowering in [this pr](https://github.com/pytorch/xla/pull/1940).
Lowering is the process of breaking a PyTorch operation down into a sequence of XlaOps. To provide a good lowering of a PyTorch operation, one needs a good grasp of what XLA is capable of. Reading the XlaOp documentation and looking into how similar ops are lowered is the best way to achieve that. You can find a minimal op lowering example in [this pr](https://github.com/pytorch/xla/pull/2969). You can also find a slightly more complicated example with backward lowering in [this pr](https://github.com/pytorch/xla/pull/2972).
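
As a concrete instance of the tip above, the op lowered in this PR needs no new XlaOp at all: `lerp` decomposes into elementwise arithmetic that XLA already provides. Restating PyTorch's documented `lerp` semantics (not something introduced by this PR), the identity being lowered is

$$\mathrm{lerp}(\mathrm{start}, \mathrm{end}, \mathrm{weight}) = \mathrm{start} + \mathrm{weight} \cdot (\mathrm{end} - \mathrm{start}),$$

which is exactly the expression built from IR values in `torch_xla/csrc/ops/ops.cpp` further down in this diff.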
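The guide's testing advice above is easiest to see as a skeleton. The sketch below is a hedged summary of the pattern the new lerp tests follow, not code from this PR: compute a reference result with stock CPU PyTorch, repeat the computation on each XLA device, compare with `AllClose`, then use counters to assert that nothing fell back to the `aten::` implementation and that the `xla::` lowering was hit. `torch::some_op` and `xla::some_op` are placeholders; the helper names are the ones already used in `test/cpp/test_aten_xla_tensor.cpp`.

```cpp
// Minimal sketch of the test pattern described in the guide; the concrete
// TestLerp* tests below follow exactly this shape.
TEST_F(AtenXlaTensorTest, TestSomeOp) {
  // 1. Build inputs and a reference result with stock (CPU) PyTorch.
  torch::Tensor a = torch::rand({3, 4}, torch::TensorOptions(torch::kFloat));
  torch::Tensor b = torch::rand({3, 4}, torch::TensorOptions(torch::kFloat));
  torch::Tensor ref = torch::some_op(a, b);  // hypothetical op under test
  // 2. Repeat the computation on every available XLA device and compare.
  ForEachDevice([&](const torch::Device& device) {
    torch::Tensor xla_a = CopyToDevice(a, device);
    torch::Tensor xla_b = CopyToDevice(b, device);
    torch::Tensor xla_ref = torch::some_op(xla_a, xla_b);
    AllClose(ref, xla_ref);
  });
  // 3. Verify dispatch: no aten:: fallback counter moved, and the xla::
  //    counter for the lowered op did.
  ExpectCounterNotChanged("aten::.*", cpp_test::GetIgnoredCounters());
  ExpectCounterChanged("xla::some_op", cpp_test::GetIgnoredCounters());
}
```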
109 changes: 109 additions & 0 deletions test/cpp/test_aten_xla_tensor.cpp
@@ -10285,5 +10285,114 @@ TEST_F(AtenXlaTensorTest, TestEarlySyncLiveTensors) {
cpp_test::GetIgnoredCounters());
}

TEST_F(AtenXlaTensorTest, TestLerp) {
torch::Tensor start =
torch::rand({3, 4}, torch::TensorOptions(torch::kFloat));
torch::Tensor end = torch::rand({3, 4}, torch::TensorOptions(torch::kFloat));
torch::Tensor weight =
torch::rand({3, 4}, torch::TensorOptions(torch::kFloat));
torch::Tensor res = torch::lerp(start, end, weight);
ForEachDevice([&](const torch::Device& device) {
torch::Tensor xla_start = CopyToDevice(start, device);
torch::Tensor xla_end = CopyToDevice(end, device);
torch::Tensor xla_weight = CopyToDevice(weight, device);
torch::Tensor xla_res = torch::lerp(xla_start, xla_end, xla_weight);
AllClose(res, xla_res);
});
ExpectCounterNotChanged("aten::.*", cpp_test::GetIgnoredCounters());
ExpectCounterChanged("xla::lerp", cpp_test::GetIgnoredCounters());
}

TEST_F(AtenXlaTensorTest, TestLerpScalar) {
torch::Tensor start =
torch::rand({3, 4}, torch::TensorOptions(torch::kFloat));
torch::Tensor end = torch::rand({3, 4}, torch::TensorOptions(torch::kFloat));
torch::Scalar weight = torch::Scalar(3.0);
torch::Tensor res = torch::lerp(start, end, weight);
ForEachDevice([&](const torch::Device& device) {
torch::Tensor xla_start = CopyToDevice(start, device);
torch::Tensor xla_end = CopyToDevice(end, device);
torch::Tensor xla_res = torch::lerp(xla_start, xla_end, weight);
AllClose(res, xla_res);
});
ExpectCounterNotChanged("aten::.*", cpp_test::GetIgnoredCounters());
ExpectCounterChanged("xla::lerp", cpp_test::GetIgnoredCounters());
}

TEST_F(AtenXlaTensorTest, TestLerpInplace) {
torch::Tensor input =
torch::rand({3, 4}, torch::TensorOptions(torch::kFloat));
torch::Tensor end = torch::rand({3, 4}, torch::TensorOptions(torch::kFloat));
torch::Tensor weight =
torch::rand({3, 4}, torch::TensorOptions(torch::kFloat));
torch::Tensor input_copy = input.clone();
input.lerp_(end, weight);
ForEachDevice([&](const torch::Device& device) {
torch::Tensor xla_input = CopyToDevice(input_copy, device);
torch::Tensor xla_end = CopyToDevice(end, device);
torch::Tensor xla_weight = CopyToDevice(weight, device);
xla_input.lerp_(xla_end, xla_weight);
AllClose(xla_input, input);
});
ExpectCounterNotChanged("aten::.*", cpp_test::GetIgnoredCounters());
ExpectCounterChanged("xla::lerp_", cpp_test::GetIgnoredCounters());
}

TEST_F(AtenXlaTensorTest, TestLerpScalarInplace) {
torch::Tensor input =
torch::rand({3, 4}, torch::TensorOptions(torch::kFloat));
torch::Tensor end = torch::rand({3, 4}, torch::TensorOptions(torch::kFloat));
torch::Scalar weight = torch::Scalar(3.0);
torch::Tensor input_copy = input.clone();
input.lerp_(end, weight);
ForEachDevice([&](const torch::Device& device) {
torch::Tensor xla_input = CopyToDevice(input_copy, device);
torch::Tensor xla_end = CopyToDevice(end, device);
xla_input.lerp_(xla_end, weight);
AllClose(xla_input, input);
});
ExpectCounterNotChanged("aten::.*", cpp_test::GetIgnoredCounters());
ExpectCounterChanged("xla::lerp_", cpp_test::GetIgnoredCounters());
}

TEST_F(AtenXlaTensorTest, TestLerpOut) {
torch::Tensor start =
torch::rand({3, 4}, torch::TensorOptions(torch::kFloat));
torch::Tensor end = torch::rand({3, 4}, torch::TensorOptions(torch::kFloat));
torch::Tensor weight =
torch::rand({3, 4}, torch::TensorOptions(torch::kFloat));
torch::Tensor res = torch::empty({3, 4}, torch::TensorOptions(torch::kFloat));
torch::lerp_out(res, start, end, weight);
ForEachDevice([&](const torch::Device& device) {
torch::Tensor xla_start = CopyToDevice(start, device);
torch::Tensor xla_end = CopyToDevice(end, device);
torch::Tensor xla_weight = CopyToDevice(weight, device);
torch::Tensor xla_res = torch::empty({3, 4}, xla_start.options());
torch::lerp_out(xla_res, xla_start, xla_end, xla_weight);
AllClose(res, xla_res);
});
ExpectCounterNotChanged("aten::.*", cpp_test::GetIgnoredCounters());
ExpectCounterChanged("xla::lerp_out", cpp_test::GetIgnoredCounters());
}

TEST_F(AtenXlaTensorTest, TestLerpScalarOut) {
torch::Tensor start =
torch::rand({3, 4}, torch::TensorOptions(torch::kFloat));
torch::Tensor end = torch::rand({3, 4}, torch::TensorOptions(torch::kFloat));
torch::Scalar weight = torch::Scalar(3.0);
torch::Tensor res = torch::empty({3, 4}, torch::TensorOptions(torch::kFloat));
torch::lerp_out(res, start, end, weight);
ForEachDevice([&](const torch::Device& device) {
torch::Tensor xla_start = CopyToDevice(start, device);
torch::Tensor xla_end = CopyToDevice(end, device);
torch::Tensor xla_res = torch::empty({3, 4}, xla_start.options());
torch::lerp_out(xla_res, xla_start, xla_end, weight);
AllClose(res, xla_res);
});
ExpectCounterNotChanged("aten::.*", cpp_test::GetIgnoredCounters());
ExpectCounterChanged("xla::lerp_out", cpp_test::GetIgnoredCounters());
}

} // namespace cpp_test
} // namespace torch_xla
50 changes: 50 additions & 0 deletions torch_xla/csrc/aten_xla_type.cpp
@@ -1519,6 +1519,56 @@ at::Tensor leaky_relu_backward(const at::Tensor& grad_output,
negative_slope.to<double>()));
}

at::Tensor lerp(const at::Tensor& self, const at::Tensor& end,
const at::Tensor& weight) {
XLA_FN_COUNTER("xla::");
return bridge::AtenFromXlaTensor(
XLATensor::lerp(bridge::GetXlaTensor(self), bridge::GetXlaTensor(end),
bridge::GetXlaTensor(weight)));
}

at::Tensor lerp(const at::Tensor& self, const at::Tensor& end,
const at::Scalar& weight) {
XLA_FN_COUNTER("xla::");
return bridge::AtenFromXlaTensor(XLATensor::lerp(
bridge::GetXlaTensor(self), bridge::GetXlaTensor(end), weight));
}

at::Tensor& lerp_(at::Tensor& self, const at::Tensor& end,
const at::Tensor& weight) {
XLA_FN_COUNTER("xla::");
XLATensor self_tensor = bridge::GetXlaTensor(self);
XLATensor::lerp_(self_tensor, bridge::GetXlaTensor(end),
bridge::GetXlaTensor(weight));
return self;
}

at::Tensor& lerp_(at::Tensor& self, const at::Tensor& end,
const at::Scalar& weight) {
XLA_FN_COUNTER("xla::");
XLATensor self_tensor = bridge::GetXlaTensor(self);
XLATensor::lerp_(self_tensor, bridge::GetXlaTensor(end), weight);
return self;
}

at::Tensor& lerp_out(const at::Tensor& self, const at::Tensor& end,
const at::Tensor& weight, at::Tensor& out) {
XLA_FN_COUNTER("xla::");
XLATensor out_tensor = bridge::GetXlaTensor(out);
XLATensor::lerp_out(out_tensor, bridge::GetXlaTensor(self),
bridge::GetXlaTensor(end), bridge::GetXlaTensor(weight));
return out;
}

at::Tensor& lerp_out(const at::Tensor& self, const at::Tensor& end,
const at::Scalar& weight, at::Tensor& out) {
XLA_FN_COUNTER("xla::");
XLATensor out_tensor = bridge::GetXlaTensor(out);
XLATensor::lerp_out(out_tensor, bridge::GetXlaTensor(self),
bridge::GetXlaTensor(end), weight);
return out;
}

at::Tensor log(const at::Tensor& self) {
XLA_FN_COUNTER("xla::");
return bridge::AtenFromXlaTensor(XLATensor::log(bridge::GetXlaTensor(self)));
5 changes: 5 additions & 0 deletions torch_xla/csrc/ops/ops.cpp
@@ -738,6 +738,11 @@ NodePtr BaddBmm(const Value& lhs, const Value& rhs, const Value& bias,
std::move(lower_fn));
}

NodePtr Lerp(const Value& start, const Value& end, const Value& weight) {
ScopePusher ir_scope(at::aten::lerp.toQualString());
return start + weight * (end - start);
}

} // namespace ops
} // namespace ir
} // namespace torch_xla
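
The one-line lowering above leans on the `+`, `-`, and `*` overloads already defined for IR `Value`s, so no new node type is needed. As a quick, standalone sanity check of the closed form itself (plain C++, independent of XLA, and not part of this PR):

```cpp
#include <cassert>
#include <cmath>

// Reference form of lerp used by the lowering: start + weight * (end - start).
double LerpReference(double start, double end, double weight) {
  return start + weight * (end - start);
}

int main() {
  // weight = 0 returns start, weight = 1 returns end, and weight = 0.25 lands
  // a quarter of the way from start to end.
  assert(std::fabs(LerpReference(1.0, 5.0, 0.0) - 1.0) < 1e-12);
  assert(std::fabs(LerpReference(1.0, 5.0, 1.0) - 5.0) < 1e-12);
  assert(std::fabs(LerpReference(1.0, 5.0, 0.25) - 2.0) < 1e-12);
  return 0;
}
```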
2 changes: 2 additions & 0 deletions torch_xla/csrc/ops/ops.h
@@ -212,6 +212,8 @@ NodePtr IsNan(const Value& input);
NodePtr BaddBmm(const Value& lhs, const Value& rhs, const Value& bias,
const Value& product_multiplier, const Value& bias_multiplier);

NodePtr Lerp(const Value& start, const Value& end, const Value& weight);

} // namespace ops
} // namespace ir
} // namespace torch_xla
13 changes: 13 additions & 0 deletions torch_xla/csrc/tensor.h
@@ -651,6 +651,19 @@ class XLATensor {
const XLATensor& input,
double negative_slope);

static XLATensor lerp(const XLATensor& input, const XLATensor& end,
const XLATensor& weight);
static XLATensor lerp(const XLATensor& input, const XLATensor& end,
const at::Scalar& weight);
static void lerp_(XLATensor& input, const XLATensor& end,
const XLATensor& weight);
static void lerp_(XLATensor& input, const XLATensor& end,
const at::Scalar& weight);
static void lerp_out(XLATensor& out, const XLATensor& input,
const XLATensor& end, const XLATensor& weight);
static void lerp_out(XLATensor& out, const XLATensor& input,
const XLATensor& end, const at::Scalar& weight);

static XLATensor log(const XLATensor& input);

static XLATensor log_base(const XLATensor& input, ir::OpKind op, double base);
42 changes: 42 additions & 0 deletions torch_xla/csrc/tensor_methods.cpp
@@ -1458,6 +1458,48 @@ XLATensor XLATensor::leaky_relu_backward(const XLATensor& grad_output,
grad_output.GetIrValue(), input.GetIrValue(), negative_slope));
}

XLATensor XLATensor::lerp(const XLATensor& input, const XLATensor& end,
const XLATensor& weight) {
return input.CreateFrom(
ir::ops::Lerp(input.GetIrValue(), end.GetIrValue(), weight.GetIrValue()));
}

XLATensor XLATensor::lerp(const XLATensor& input, const XLATensor& end,
const at::Scalar& weight) {
ir::Value weight_val = GetIrValueForScalar(
weight, input.shape().get().element_type(), input.GetDevice());
return input.CreateFrom(
ir::ops::Lerp(input.GetIrValue(), end.GetIrValue(), weight_val));
}

void XLATensor::lerp_(XLATensor& input, const XLATensor& end,
const XLATensor& weight) {
input.SetInPlaceIrValue(
ir::ops::Lerp(input.GetIrValue(), end.GetIrValue(), weight.GetIrValue()));
}

void XLATensor::lerp_(XLATensor& input, const XLATensor& end,
const at::Scalar& weight) {
ir::Value weight_val = GetIrValueForScalar(
weight, input.shape().get().element_type(), input.GetDevice());
input.SetInPlaceIrValue(
ir::ops::Lerp(input.GetIrValue(), end.GetIrValue(), weight_val));
}

void XLATensor::lerp_out(XLATensor& out, const XLATensor& input,
const XLATensor& end, const XLATensor& weight) {
out.SetInPlaceIrValue(
ir::ops::Lerp(input.GetIrValue(), end.GetIrValue(), weight.GetIrValue()));
}

void XLATensor::lerp_out(XLATensor& out, const XLATensor& input,
const XLATensor& end, const at::Scalar& weight) {
ir::Value weight_val = GetIrValueForScalar(
weight, input.shape().get().element_type(), input.GetDevice());
out.SetInPlaceIrValue(
ir::ops::Lerp(input.GetIrValue(), end.GetIrValue(), weight_val));
}

XLATensor XLATensor::log(const XLATensor& input) {
return input.CreateFrom(ir::ops::Log(input.GetIrValue()));
}
6 changes: 6 additions & 0 deletions xla_native_functions.yaml
@@ -307,6 +307,12 @@ supported:
- sigmoid_backward
- tanh_backward
- ger
- lerp_.Scalar
- lerp_.Tensor
- lerp.Scalar_out
- lerp.Tensor_out
- lerp.Scalar
- lerp.Tensor
autograd:
- max_pool2d
- max_pool3d
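
Listing the six `lerp` overloads under `supported:` is what routes those ATen calls to the new `lerp*` functions in `aten_xla_type.cpp` instead of letting them fall back. A hedged end-to-end usage sketch follows; it assumes a PyTorch build with torch_xla loaded so the `xla` device type is registered, and the `"xla:0"` device string and printout are illustrative rather than taken from this PR:

```cpp
#include <iostream>
#include <torch/torch.h>

int main() {
  // Assumption: torch_xla is linked/loaded so "xla:0" resolves to an XLA device.
  torch::Device device("xla:0");
  torch::Tensor start = torch::rand({3, 4}).to(device);
  torch::Tensor end = torch::rand({3, 4}).to(device);
  torch::Tensor weight = torch::rand({3, 4}).to(device);
  // With the yaml entries above, this dispatches to the xla::lerp lowering
  // rather than falling back to the CPU aten kernel.
  torch::Tensor out = torch::lerp(start, end, weight);
  std::cout << out.cpu() << std::endl;
  return 0;
}
```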