Reduction ops don't support dynamic dims #3573

Open
renxida opened this issue Jul 30, 2024 · 0 comments

renxida (Collaborator) commented Jul 30, 2024

Around 10%¹ of the total failures in the IREE ONNX node tests are caused by "failed to legalize operation 'torch.aten.sum.dim_IntList' that was explicitly marked illegal". Every single case I checked was due to a dynamic dim being fed to line 488 in Reduction.cpp.

I thought it might be a problem with the test cases, but per the discussion at nod-ai/SHARK-TestSuite#308 this does not seem to be the case.

We should figure out how to support dynamic dims in reduction ops. The error is thrown during computeReductionOpInfoForDimVariantOp:

    bool isNoneOrEmptyDimList = isa<Torch::NoneType>(op.getDim().getType());
    if (matchPattern(op.getDim(), m_TorchListOfConstantInts(dimList))) {
      // Fix negative dimensions, if any, before adding to the list.
      for (int64_t dim : dimList) {
        dim = toPositiveDim(dim, inputType.getRank());
        // Drop invalid dimensions
        if (isValidDim(dim, inputType.getRank()))
          opInfo.dimSet.insert(dim);
      }
      if (dimList.empty())
        isNoneOrEmptyDimList = true;
    } else if (matchPattern(op.getDim(), m_TorchConstantInt(&dim))) {
      dim = toPositiveDim(dim, inputType.getRank());
      if (!isValidDim(dim, inputType.getRank()))
        return rewriter.notifyMatchFailure(
            op, "`dim` argument must be valid, invalid received.");
      opInfo.dimSet.insert(dim);
    } else if (!isNoneOrEmptyDimList) {
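      // Reaching here means `dim` is a list whose elements are not all
      // compile-time constant ints (e.g. they come from runtime size
      // queries), so the op is left un-lowered and later reported as
      // "failed to legalize".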
      return rewriter.notifyMatchFailure(
          op, "`dim` argument must be a constant int list or None");
    }
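
For illustration, here is a minimal hand-written sketch (not taken from the test suite; the function name and types are mine) of the kind of IR I would expect to hit this failure path: the dim list is built from a runtime !torch.int value instead of a torch.constant.int, so m_TorchListOfConstantInts cannot match it.

    func.func @sum_dynamic_dim(%arg0: !torch.vtensor<[?,?],f32>, %dim: !torch.int) -> !torch.vtensor<[?],f32> {
      %false = torch.constant.bool false
      %none = torch.constant.none
      // %dim is a block argument, not a torch.constant.int, so the
      // constant-int-list pattern above has nothing to fold.
      %dims = torch.prim.ListConstruct %dim : (!torch.int) -> !torch.list<int>
      %0 = torch.aten.sum.dim_IntList %arg0, %dims, %false, %none
          : !torch.vtensor<[?,?],f32>, !torch.list<int>, !torch.bool, !torch.none -> !torch.vtensor<[?],f32>
      return %0 : !torch.vtensor<[?],f32>
    }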

The computed ReductionOpInfo is then used in createReductionLinalgGeneric:

Value torch_to_linalg::createReductionLinalgGeneric(
    OpBuilder &b, Location loc, const ReductionOpInfo &opInfo, Value initElem,
    function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuild) {
  auto inputType = cast<RankedTensorType>(opInfo.tensorOperand.getType());

  // Get the result shape by obtaining the size of each
  // dimension in the input tensor that is not getting reduced.
  // If `opInfo.keepDim` is true, the rank of the output tensor
  // is kept the same as the rank of the input tensor, and the
  // reduced dimensions are set to have size 1.
  auto c1 = b.create<arith::ConstantIndexOp>(loc, /*value=*/1);
  SmallVector<Value> resultShape;
  for (int64_t i = 0; i < inputType.getRank(); i++) {
    auto currentDimSize = b.create<tensor::DimOp>(loc, opInfo.tensorOperand, i);
    if (!opInfo.dimSet.contains(i))
      resultShape.push_back(currentDimSize);
    else if (opInfo.keepDim)
      resultShape.push_back(c1);
  }

  // Create the affine expressions that will be used to
  // iterate over the input and output tensors.
  // Here we also set the type of iterator: parallel or reduction.
  SmallVector<AffineExpr> exprs;
  SmallVector<utils::IteratorType> iteratorTypes;
  SmallVector<AffineExpr> resultExprs;
  for (auto size :
       llvm::enumerate(makeShapeTorchCompatible(inputType.getShape()))) {
    exprs.push_back(b.getAffineDimExpr(size.index()));

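    // Note: dimSet membership is decided at compile time; it selects a
    // static iterator type below, so the set of reduced dimensions
    // cannot depend on a runtime value.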
    if (opInfo.dimSet.contains(size.index())) {
      iteratorTypes.push_back(utils::IteratorType::reduction);
      // If `opInfo.keepDim`, create affine map to the first element
      // in the current dimension.
      if (opInfo.keepDim)
        resultExprs.push_back(b.getAffineConstantExpr(0));
    } else {
      iteratorTypes.push_back(utils::IteratorType::parallel);
      resultExprs.push_back(b.getAffineDimExpr(size.index()));
    }
  }
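
This is where dynamic dims get hard: the iterator types and indexing maps computed above end up as static attributes on the generated linalg.generic. As a rough hand-written sketch (illustrative only, not code from the repo), a sum over dim 1 of a rank-2 tensor lowers to something like:

    #map = affine_map<(d0, d1) -> (d0, d1)>
    #map1 = affine_map<(d0, d1) -> (d0)>
    // indexing_maps and iterator_types are compile-time attributes, so
    // which dimensions get reduced cannot be a runtime value.
    %result = linalg.generic
        {indexing_maps = [#map, #map1],
         iterator_types = ["parallel", "reduction"]}
        ins(%input : tensor<?x?xf32>) outs(%init : tensor<?xf32>) {
      ^bb0(%in: f32, %acc: f32):
        %sum = arith.addf %in, %acc : f32
        linalg.yield %sum : f32
    } -> tensor<?xf32>

Dynamic tensor sizes are already handled (via tensor.dim above); it is a dynamic dim list that has no place to go here, so it presumably needs to be resolved to constants before this lowering runs, or routed through a different lowering entirely.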

Footnotes

  1. By my count in https://gist.github.com/renxida/2706b7522b38d8c6271bfff95fb62609

renxida changed the title from "Add support for dynamic dims in reduction ops" to "Reduction ops don't support dynamic dims" on Jul 30, 2024