Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update precision in the ONNX strided_slice, update precision of ToScalar #6272

Merged
merged 2 commits into from
Aug 14, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions python/tvm/relay/frontend/onnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -1045,8 +1045,8 @@ def _impl_v1(cls, inputs, attr, params):
end = list(attr['ends'])

return _op.strided_slice(inputs[0],
begin=_expr.const(begin, dtype="int32"),
end=_expr.const(end, dtype="int32"))
begin=_expr.const(begin, dtype="int64"),
end=_expr.const(end, dtype="int64"))

@classmethod
def _impl_v10(cls, inputs, attr, params):
Expand All @@ -1063,8 +1063,8 @@ def _impl_v10(cls, inputs, attr, params):
starts = new_starts
ends = new_ends
return _op.strided_slice(inputs[0],
begin=_expr.const(starts, dtype="int32"),
end=_expr.const(ends, dtype="int32"))
begin=_expr.const(starts, dtype="int64"),
end=_expr.const(ends, dtype="int64"))


class Gather(OnnxOpConverter):
Expand Down
6 changes: 3 additions & 3 deletions src/relay/transforms/pattern_util.h
Original file line number Diff line number Diff line change
Expand Up @@ -374,7 +374,7 @@ inline bool IsEqualScalar(const Expr& a, const Expr& b) {
* \param i element index
* \return Converted scalar value.
*/
static inline double ToScalar(const runtime::NDArray& array, size_t i = 0) {
static inline long double ToScalar(const runtime::NDArray& array, size_t i = 0) {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why is double not sufficient?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Because a double has only 53 bits of significand precision (52 bits explicitly stored plus one implicit leading bit), it can't represent every int64_t value exactly.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So we get rounding errors if we pass in large int64_t values

Copy link
Contributor Author

@mbrookhart mbrookhart Aug 14, 2020

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

On x86, long double is the 80-bit extended-precision format with a 64-bit significand, which is enough to represent every int64 value exactly. On PowerPC it is typically the 128-bit "double-double" format with about 106 bits of effective significand; on AArch64 it is IEEE binary128 with a 113-bit significand. (Note: on some targets — e.g. 32-bit ARM and MSVC — long double is the same as double, so it's worth confirming this fix covers those platforms.)

if (array->dtype.code == kDLInt) {
if (array->dtype.bits == 8) {
return reinterpret_cast<int8_t*>(array->data)[i];
Expand Down Expand Up @@ -423,8 +423,8 @@ static inline Array<Integer> ToVector(const runtime::NDArray& array) {
size_t len = array.Shape().front();
Array<Integer> out;
for (size_t i = 0; i < len; ++i) {
double elem_val = ToScalar(array, i);
out.push_back(Integer(static_cast<int>(elem_val)));
long double elem_val = ToScalar(array, i);
out.push_back(Integer(IntImm(DataType::Int(32), static_cast<int64_t>(elem_val))));
}
return out;
}
Expand Down
11 changes: 6 additions & 5 deletions tests/python/frontend/onnx/test_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -478,15 +478,15 @@ def _test_slice_iteration_v10(indata, outdata, starts, ends, axes=None):
inputs = [
helper.make_tensor_value_info("data", TensorProto.FLOAT,
list(indata.shape)),
helper.make_tensor_value_info("starts", TensorProto.INT32,
helper.make_tensor_value_info("starts", TensorProto.INT64,
list(starts.shape)),
helper.make_tensor_value_info("ends", TensorProto.INT32,
helper.make_tensor_value_info("ends", TensorProto.INT64,
list(ends.shape))
]
initializer = [
helper.make_tensor("starts", TensorProto.INT32, list(starts.shape),
helper.make_tensor("starts", TensorProto.INT64, list(starts.shape),
starts),
helper.make_tensor("ends", TensorProto.INT32, list(ends.shape), ends)
helper.make_tensor("ends", TensorProto.INT64, list(ends.shape), ends)
]

if axes:
Expand Down Expand Up @@ -534,7 +534,8 @@ def test_slice():
_test_slice_iteration_v10(x, x[0:3, 0:10], (0, 0), (3, 10), (0, 1))
_test_slice_iteration_v10(x, x[:, :, 3:4], (0, 0, 3), (20, 10, 4))
_test_slice_iteration_v10(x, x[:, 1:1000], (1), (1000), (1))
_test_slice_iteration_v10(x, x[:, 0:-1], (0), (-1), (1))
x = np.random.randn(1, 1, 1, 128).astype(np.float32)
_test_slice_iteration_v10(x, x, (0, 0), (9223372036854775807, 9223372036854775807), (0, 3))


def _test_onnx_op_elementwise(inshape, outfunc, npargs, dtype, opname, kwargs):
Expand Down