Skip to content

Commit

Permalink
[Onnx] Fix NLL Loss tests (apache#8971)
Browse files Browse the repository at this point in the history
* support negatibve indices in gather

* move check to Tensor level indexing, gathernd

* add test, update transform.h

* remove unneeded gather

* missing gather nd change

* update tests

* proper tensor comparison

* blacking

* lint

* fix error

* turn on test

* missing test case

* revert changes

* add normalize_gather_indices

* undo change

* update

* more removing diffs

* more undoing

Co-authored-by: Andrew Zhao Luo <[email protected]>
  • Loading branch information
2 people authored and ylc committed Jan 13, 2022
1 parent 68b4c2f commit 1971e2e
Show file tree
Hide file tree
Showing 3 changed files with 26 additions and 13 deletions.
30 changes: 24 additions & 6 deletions python/tvm/relay/frontend/onnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -1561,7 +1561,12 @@ def _impl_common(cls, data, indices, batch_dims=0):
indices_shape = infer_shape(indices)
indices = _op.transpose(indices, axes=[-1] + list(range(indices_dims - 1)))
index_rank = indices_shape[-1]
return _op.gather_nd(data, indices, batch_dims, index_rank)
return _op.gather_nd(
data,
indices,
batch_dims=batch_dims,
index_rank=index_rank,
)

@classmethod
def _impl_v1(cls, inputs, attr, params):
Expand Down Expand Up @@ -3554,6 +3559,11 @@ def _impl_v13(cls, inputs, attr, params):
)

input_tensor, target_tensor = inputs[0], inputs[1]

# Convert negative indices --> positive indices for gather ops, note we have to
# use the original target tensor to interact with ignore_index to have proper behavior.
normalized_target_tensor = normalize_gather_indices(input_tensor, target_tensor, 1)

if len(inputs) == 3:
weight_tensor = inputs[2]
else:
Expand All @@ -3563,12 +3573,18 @@ def _impl_v13(cls, inputs, attr, params):
dtype=input_tensor.type_annotation.dtype,
)

loss = -relay.gather(input_tensor, axis=1, indices=relay.expand_dims(target_tensor, 1))
loss = -relay.gather(
input_tensor,
axis=1,
indices=relay.expand_dims(normalized_target_tensor, 1),
)
loss = relay.squeeze(loss, axis=[1])

expanded_target_tensor = relay.expand_dims(target_tensor, 0)
expanded_target_tensor = relay.nn.batch_flatten(expanded_target_tensor)
flattened_weights = relay.gather_nd(weight_tensor, expanded_target_tensor)
expanded_normalized_target_tensor = relay.expand_dims(normalized_target_tensor, 0)
expanded_normalized_target_tensor = relay.nn.batch_flatten(
expanded_normalized_target_tensor
)
flattened_weights = relay.gather_nd(weight_tensor, expanded_normalized_target_tensor)
select_weights = relay.reshape_like(flattened_weights, loss)
loss *= select_weights

Expand All @@ -3578,7 +3594,9 @@ def _impl_v13(cls, inputs, attr, params):
target_tensor, relay.const(ignore_index, dtype=target_tensor.type_annotation.dtype)
)
mask_tensor = relay.const(1, dtype="int8") - relay.cast(mask_tensor, "int8")
loss *= relay.cast_like(mask_tensor, loss)
loss = relay.where(
mask_tensor, loss, relay.const(0, infer_type(loss).checked_type.dtype)
)

# This is not explained super clearly in the onnx spec, but masked values don't
# contribute toward the final value in reduction
Expand Down
5 changes: 0 additions & 5 deletions tests/python/frontend/onnx/test_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -4771,11 +4771,6 @@ def verify_eyelike(indata):
"test_nllloss_NCd1d2d3_sum_weight_high_ii_expanded",
"test_nllloss_NCd1d2d3d4d5_mean_weight_expanded",
"test_nllloss_NCd1d2d3d4d5_none_no_weight_expanded",
# These nllloss tests are flaky and sometimes gives NaNs
# Investigate it here: https://github.com/apache/tvm/issues/8918
"test_nllloss_NCd1d2d3_none_no_weight_negative_ii",
# Investigate it here: https://github.com/apache/tvm/issues/8964
"test_nllloss_NCd1d2d3_sum_weight_high_ii",
"test_qlinearmatmul_2D",
"test_qlinearmatmul_3D",
"test_range_float_type_positive_delta_expanded",
Expand Down
4 changes: 2 additions & 2 deletions tests/python/relay/test_any.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,8 @@
from tvm.relay.loops import while_loop
from tvm.relay.testing import run_infer_type as infer_type

from utils.assert_diagnostic import DiagnosticTesting
from utils import ref_funcs
from utils.assert_diagnostic import DiagnosticTesting


def int32(val):
Expand Down Expand Up @@ -2022,7 +2022,7 @@ def test_gather_nd():
def verify_gather_nd(data_shape, indices_shape, data_shape_np, indices_shape_np, batch_dims=0):
x = relay.var("x", relay.TensorType(data_shape, "float32"))
y = relay.var("y", relay.TensorType(indices_shape, "int32"))
z = relay.gather_nd(x, y, batch_dims, indices_shape[0])
z = relay.gather_nd(x, y, batch_dims=batch_dims, index_rank=indices_shape[0])

mod = tvm.IRModule()
mod["main"] = relay.Function([x, y], z)
Expand Down

0 comments on commit 1971e2e

Please sign in to comment.