-
Notifications
You must be signed in to change notification settings - Fork 52
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Handle empty tensors during definition of cat #3313
base: main
Are you sure you want to change the base?
Conversation
This helped determine why the original repro failed
Co-authored-by: Naoya Maruyama <[email protected]>
!build --diff |
@@ -479,6 +481,21 @@ bool hasSimilarDtype(DataType base, DataType dt) { | |||
NVF_THROW("Unrecognized base dtype."); | |||
} | |||
|
|||
Val* zeroForDtype(DataType dtype) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can we use Fusion::zeroVal(DatatType)
?
if (dim == cat_dim) { | ||
shape[dim] = SimplifyingIrBuilder::addExpr(shape[dim], extent); | ||
} else if (shape[dim] == nullptr) { | ||
shape[dim] = extent; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'm not sure what the correct behavior should be here. Can we just use the first extent for a non-cat dimension? Suppose the first cat input has size zero for that dimension, the output of the cat would also have zero for the dimension. Would that be expected?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
In that case it should be caught by the inputs.empty()
condition at line 678, since all of the inputs should be empty. Actually though, I just realized that assumes that if one of the inputs has constant size 0 in one dimension that all the other inputs will have constant size zero in the same dimension. If some are symbolic this definition should be proof that they're equal by exact mapping, but they won't be detected at line 660. I guess what I should do instead is fire this condition if any input has a zero non-cat dimension.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
if one of the inputs has constant size 0 in one dimension that all the other inputs will have constant size zero in the same dimension.
Yeah, this was what I was thinking about. I'm not sure if that's actually allowed. Is it?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We could have something like this I think:
tv0[ 2, 0 ]
tv1[ 3, i0 ]
tv2 = cat({tv0, tv1}, /*axis=*/0)
In this case we would normally exact map tv0->axis(1)
with tv1->axis(1)
so that i0
must be 0 or else we'll hit an error in ExpressionEvaluator::propagateBoundValuesThroughExactMaps()
. But now since we'll just translate this to full({5, 0})
I don't think there's any such constraint, so maybe it would be legal to pass in a tv1
with shape (3,2) and we would not hit an error.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
OK, makes sense.
The failure in jit_binary_tests_17_A100_1/3 is real. I have a fix but am waiting to push until the codediff is done. |
!test --diff |
In #3292, a
PadOp
is created for a pad of two inputs, one of which has zero size in the cat dimension. This caused an error when we replaced the empty input with aFullOp
output. That is addressed in the remove_empty pass by PR #3301. This PR aims to additionally simplify the fusion definition when we have concrete sizes. In particular, arguments topad
andcat
are inspected for empty dimensions and if found, we avoid usingPadOp
orCatOp
unless necessary. The behavior is demonstrated in the includedCatOfEmpty
test: