Skip to content

Commit

Permalink
Bump stablehlo to openxla/stablehlo@fd52182 (#2821)
Browse files Browse the repository at this point in the history
With the recent LLVM integrate and changes from
llvm/llvm-project#78260, we hit this build error
in Stablehlo (which is quite old).
```
external/stablehlo/stablehlo/transforms/StablehloRefineShapes.cpp:1020:14: error: no member named 'startRootUpdate' in 'mlir::PatternRewriter'
    rewriter.startRootUpdate(op);
    ~~~~~~~~ ^
external/stablehlo/stablehlo/transforms/StablehloRefineShapes.cpp:1026:16: error: no member named 'finalizeRootUpdate' in 'mlir::PatternRewriter'
      rewriter.finalizeRootUpdate(op);
      ~~~~~~~~ ^
external/stablehlo/stablehlo/transforms/StablehloRefineShapes.cpp:1029:16: error: no member named 'cancelRootUpdate' in 'mlir::PatternRewriter'
      rewriter.cancelRootUpdate(op);
      ~~~~~~~~ ^
external/stablehlo/stablehlo/transforms/StablehloRefineShapes.cpp:1108:14: error: no member named 'updateRootInPlace' in 'mlir::PatternRewriter'
    rewriter.updateRootInPlace(op->getParentOp(), [&]() { return; });
    ~~~~~~~~ ^
4 errors generated.
Target @torch-mlir//:torch-mlir-opt failed to build
```

I'm still puzzled as to how this didn't fail with the CMake merge gating
CI (do we not test Stablehlo builds/tests?). In any case, bumping our
submodule to openxla/stablehlo#1918 fixes it.

It exposes a new failing lit test in TorchToStablehlo though, that I
have looped stablehlo developers into
([here](https://discord.com/channels/999073994483433573/999074539138990131/1201235845391331419)).
```
bazel run @torch-mlir//test/Conversion:TorchToStablehlo/scatter.mlir.test 

...external/torch-mlir/test/Conversion/TorchToStablehlo/scatter.mlir
within split at <stdin>:1 offset :33:8: error: unexpected error: Expects non-empty reduction block for type inference                                                                               
  %0 = torch.aten.scatter.src %arg0, %int0, %arg1, %arg2 : !torch.vtensor<[?,?],si64>, !torch.int, !torch.vtensor<[?,?],si64>, !torch.vtensor<[?,?],si64> -> !torch.vtensor<[?,?],si64>             
       ^                                                                                                                                                                                            
LLVM ERROR: Failed to infer result type(s).               
```

Bazel CI:
https://github.com/sjain-stanford/torch-mlir/actions/runs/7732673480/job/21083102228
  • Loading branch information
sjain-stanford committed Jan 31, 2024
1 parent 54e2587 commit 8a17c98
Show file tree
Hide file tree
Showing 8 changed files with 61 additions and 109 deletions.
2 changes: 1 addition & 1 deletion externals/stablehlo
Submodule stablehlo updated 528 files
18 changes: 9 additions & 9 deletions lib/Conversion/TorchToStablehlo/Basic.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -377,12 +377,12 @@ class ConvertAtenAddSubOp : public OpConversionPattern<AtenOpT> {
if (!skipMultiplyAlpha(op.getAlpha())) {
Value alpha = hlo::scalarToStablehloTensor(rewriter, op,
adaptor.getAlpha(), outElemTy);
DenseIntElementsAttr bcastDimensions;
DenseI64ArrayAttr bcastDimensions;
rhs = rewriter.create<chlo::BroadcastMulOp>(op->getLoc(), rhs, alpha,
bcastDimensions);
}

DenseIntElementsAttr bcastDimensions;
DenseI64ArrayAttr bcastDimensions;
rewriter.replaceOpWithNewOp<ChloOpT>(op, outType, lhs, rhs,
bcastDimensions);
return success();
Expand Down Expand Up @@ -424,7 +424,7 @@ class ConvertAtenMulDivOp : public OpConversionPattern<AtenOpT> {
rhs = hlo::scalarToStablehloTensor(rewriter, op, adaptor.getOther(),
outElemTy);
}
DenseIntElementsAttr bcastDimensions;
DenseI64ArrayAttr bcastDimensions;
lhs = hlo::promoteType(rewriter, op.getLoc(), lhs, outType);
rhs = hlo::promoteType(rewriter, op.getLoc(), rhs, outType);
auto loc = op.getLoc();
Expand Down Expand Up @@ -542,7 +542,7 @@ class ConvertAtenCompareOp : public OpConversionPattern<AtenOpT> {
} else {
return op.emitError("operator haven't been supported");
}
DenseIntElementsAttr bcastDimensions;
DenseI64ArrayAttr bcastDimensions;
rewriter.replaceOpWithNewOp<chlo::BroadcastCompareOp>(
op, outType, lhs, rhs, bcastDimensions, compareDirectionAttr,
compareTypeAttr);
Expand Down Expand Up @@ -570,7 +570,7 @@ class ConvertAtenLogicalBinaryOp : public OpConversionPattern<AtenOpT> {
Value rhs =
hlo::promoteType(rewriter, op.getLoc(), adaptor.getOther(), outType);

DenseIntElementsAttr bcastDimensions;
DenseI64ArrayAttr bcastDimensions;
rewriter.replaceOpWithNewOp<ChloOpT>(op, outType, lhs, rhs,
bcastDimensions);
return success();
Expand Down Expand Up @@ -757,7 +757,7 @@ LogicalResult ConvertAtenOp<AtenBroadcastToOp>::matchAndRewrite(
llvm::to_vector<4>(llvm::seq<int64_t>(leadingRank, totalRank));
rewriter.replaceOpWithNewOp<stablehlo::DynamicBroadcastInDimOp>(
op, outType, self, bcastShapeTensor,
rewriter.getI64TensorAttr(dimensionNumbers));
rewriter.getDenseI64ArrayAttr(dimensionNumbers));
}
return success();
}
Expand Down Expand Up @@ -887,7 +887,7 @@ LogicalResult ConvertAtenOp<AtenPowTensorScalarOp>::matchAndRewrite(
if (!rhsType) {
rhs = hlo::scalarToStablehloTensor(rewriter, op, rhs, outElemTy);
}
DenseIntElementsAttr bcastDimensions;
DenseI64ArrayAttr bcastDimensions;
lhs = hlo::promoteType(rewriter, op.getLoc(), lhs, outType);
rhs = hlo::promoteType(rewriter, op.getLoc(), rhs, outType);
auto loc = op.getLoc();
Expand Down Expand Up @@ -1478,7 +1478,7 @@ LogicalResult ConvertAtenOp<AtenArangeStartStepOp>::matchAndRewrite(

Value window =
rewriter.create<stablehlo::DynamicIotaOp>(loc, outType, resultLength, 0);
DenseIntElementsAttr broadcastDimensions;
DenseI64ArrayAttr broadcastDimensions;
Value mulOut = rewriter.create<chlo::BroadcastMulOp>(loc, window, step,
broadcastDimensions);
rewriter.replaceOpWithNewOp<chlo::BroadcastAddOp>(op, mulOut, start,
Expand Down Expand Up @@ -1721,7 +1721,7 @@ LogicalResult ConvertAtenOp<AtenFillScalarOp>::matchAndRewrite(
rewriter.create<shape::ShapeOfOp>(op->getLoc(), adaptor.getSelf());
Value bcastScalar = rewriter.create<stablehlo::DynamicBroadcastInDimOp>(
op->getLoc(), outType, scalarTensor, shapeTensor,
rewriter.getI64TensorAttr({}));
rewriter.getDenseI64ArrayAttr({}));
rewriter.replaceOp(op, bcastScalar);
return success();
}
Expand Down
10 changes: 6 additions & 4 deletions lib/Conversion/TorchToStablehlo/GatherScatter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -334,7 +334,8 @@ LogicalResult ConvertAtenOp<AtenEmbeddingBagPaddingIdxOp>::matchAndRewrite(
return failure();

auto stablehloReduceOp = rewriter.create<stablehlo::ReduceOp>(
op.getLoc(), gatherOutput, initValue, rewriter.getI64TensorAttr({0}));
op.getLoc(), gatherOutput, initValue, rewriter.getDenseI64ArrayAttr({0}),
elementTy);

Region &region = stablehloReduceOp.getBody();
Block &block = region.emplaceBlock();
Expand Down Expand Up @@ -510,7 +511,7 @@ LogicalResult ConvertAtenOp<AtenGatherOp>::matchAndRewrite(

rewriter.replaceOpWithNewOp<stablehlo::GatherOp>(
op, input, gatherIndicies, dimsAttr,
rewriter.getI64TensorAttr(sliceSizes));
rewriter.getDenseI64ArrayAttr(sliceSizes));
return success();
}

Expand Down Expand Up @@ -666,7 +667,8 @@ LogicalResult ConvertAtenOp<AtenScatterSrcOp>::matchAndRewrite(
/*indexVectorDim=*/indexVecDim);

auto stablehloScatterOp = rewriter.create<stablehlo::ScatterOp>(
loc, input, scatterIndicies, src, scatterDimensionNumbers, false, false);
loc, inputType, input, scatterIndicies, src, scatterDimensionNumbers,
false, false);

// config update computation function: just return the element from src.
Block &block = stablehloScatterOp.getUpdateComputation().emplaceBlock();
Expand Down Expand Up @@ -833,7 +835,7 @@ LogicalResult ConvertAtenOp<AtenIndexTensorHackedTwinOp>::matchAndRewrite(

rewriter.replaceOpWithNewOp<stablehlo::GatherOp>(
op, resultType, input, finalIndexTensor, dimsAttr,
rewriter.getI64TensorAttr(sliceSizes));
rewriter.getDenseI64ArrayAttr(sliceSizes));
return success();
}

Expand Down
34 changes: 12 additions & 22 deletions lib/Conversion/TorchToStablehlo/Linear.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -39,10 +39,7 @@ Value getBroadcastTensor(PatternRewriter &rewriter, Operation *op, Value tensor,
RankedTensorType outTy =
RankedTensorType::get(shape, tensorTy.getElementType());

RankedTensorType attrTy =
RankedTensorType::get({static_cast<int64_t>(broadcastDims.size())},
rewriter.getIntegerType(64));
auto broadcastAttr = DenseIntElementsAttr::get(attrTy, broadcastDims);
auto broadcastAttr = rewriter.getDenseI64ArrayAttr(broadcastDims);

auto broadcast = rewriter.create<stablehlo::DynamicBroadcastInDimOp>(
loc, outTy, tensor, stablehloShape, broadcastAttr);
Expand Down Expand Up @@ -549,8 +546,7 @@ class ConvertAtenConvolutionOp : public ConvertAtenOp<AtenConvolutionOp> {

// Prepare for transposed convolution
SmallVector<int64_t> stablehloStrideVec(nSpatialDims, 1);
DenseIntElementsAttr stablehloStride =
rewriter.getI64TensorAttr(stablehloStrideVec);
auto stablehloStride = rewriter.getDenseI64ArrayAttr(stablehloStrideVec);
SmallVector<int64_t> stablehloPaddingVec(nSpatialDims * 2, 0);
for (int i = 0; i < nSpatialDims; ++i) {
int64_t padInt = dilation[i] * (weightShape[i + 2] - 1) - padding[i];
Expand All @@ -563,15 +559,15 @@ class ConvertAtenConvolutionOp : public ConvertAtenOp<AtenConvolutionOp> {
stablehloPaddingVec);
SmallVector<int64_t> stablehloLhsDilationVec(nSpatialDims);
std::copy(stride.begin(), stride.end(), stablehloLhsDilationVec.begin());
DenseIntElementsAttr stablehloLhsDilation =
rewriter.getI64TensorAttr(stablehloLhsDilationVec);
auto stablehloLhsDilation =
rewriter.getDenseI64ArrayAttr(stablehloLhsDilationVec);
SmallVector<int64_t> stablehloRhsDilationVec(nSpatialDims);
std::copy(dilation.begin(), dilation.end(),
stablehloRhsDilationVec.begin());
DenseIntElementsAttr stablehloRhsDilation =
rewriter.getI64TensorAttr(stablehloRhsDilationVec);
auto stablehloRhsDilation =
rewriter.getDenseI64ArrayAttr(stablehloRhsDilationVec);

DenseElementsAttr windowReversal;
DenseBoolArrayAttr windowReversal;
ArrayAttr precisionConfig;

SmallVector<int64_t> spatialDims;
Expand Down Expand Up @@ -614,10 +610,7 @@ class ConvertAtenConvolutionOp : public ConvertAtenOp<AtenConvolutionOp> {
int64_t nDims = outType.getRank();

// Get stablehlo::ConvolutionOp attributes
DenseIntElementsAttr stablehloWindowStride = DenseIntElementsAttr::get(
RankedTensorType::get({static_cast<long int>(stride.size())},
rewriter.getI64Type()),
stride);
auto stablehloWindowStride = rewriter.getDenseI64ArrayAttr(stride);
std::vector<int64_t> stablehloPaddingVec;
for (size_t i = 0; i < padding.size(); i++) {
stablehloPaddingVec.emplace_back(padding[i]);
Expand All @@ -628,10 +621,7 @@ class ConvertAtenConvolutionOp : public ConvertAtenOp<AtenConvolutionOp> {
{static_cast<long int>(padding.size()), static_cast<long int>(2)},
rewriter.getI64Type()),
stablehloPaddingVec);
DenseIntElementsAttr stablehloRhsDilation = DenseIntElementsAttr::get(
RankedTensorType::get({static_cast<long int>(dilation.size())},
rewriter.getI64Type()),
dilation);
auto stablehloRhsDilation = rewriter.getDenseI64ArrayAttr(dilation);
SmallVector<int64_t> spatialDimensions;
for (int64_t i = 2; i < nDims; i++) {
spatialDimensions.emplace_back(i);
Expand All @@ -648,8 +638,8 @@ class ConvertAtenConvolutionOp : public ConvertAtenOp<AtenConvolutionOp> {
/*outputSpatialDimensions=*/spatialDimensions);

// stablehlo::ConvolutionOp's optional attributes, leave them as default
DenseIntElementsAttr stablehloLhsDilation;
DenseElementsAttr windowReversal;
DenseI64ArrayAttr stablehloLhsDilation;
DenseBoolArrayAttr windowReversal;
ArrayAttr precisionConfig;

auto stablehloConvOp = rewriter.create<stablehlo::ConvolutionOp>(
Expand Down Expand Up @@ -781,7 +771,7 @@ class ConvertAtenConvolutionOp : public ConvertAtenOp<AtenConvolutionOp> {
options.dimSizeIndexBits);
bias = hlo::promoteType(rewriter, op.getLoc(), bias, outTy);

DenseIntElementsAttr bcastDimensions;
DenseI64ArrayAttr bcastDimensions;
rewriter.replaceOpWithNewOp<chlo::BroadcastAddOp>(
op, outTy, stablehloConvResult, bias, bcastDimensions);
return success();
Expand Down
73 changes: 18 additions & 55 deletions lib/Conversion/TorchToStablehlo/Pooling.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -136,19 +136,10 @@ LogicalResult ConvertAtenOp<AtenMaxPool2dOp>::matchAndRewrite(
stablehloPadding[stablehloPadding.size() - 2] = padding[1];
stablehloPadding[stablehloPadding.size() - 1] = padding[1];

DenseIntElementsAttr windowDimensions = DenseIntElementsAttr::get(
RankedTensorType::get({static_cast<int64_t>(stablehloKernelSize.size())},
rewriter.getI64Type()),
stablehloKernelSize);
DenseIntElementsAttr windowStrides = DenseIntElementsAttr::get(
RankedTensorType::get({static_cast<int64_t>(stablehloStride.size())},
rewriter.getI64Type()),
stablehloStride);
DenseIntElementsAttr baseDilations;
DenseIntElementsAttr windowDilations = DenseIntElementsAttr::get(
RankedTensorType::get({static_cast<int64_t>(stablehloDilation.size())},
rewriter.getI64Type()),
stablehloDilation);
auto windowDimensions = rewriter.getDenseI64ArrayAttr(stablehloKernelSize);
auto windowStrides = rewriter.getDenseI64ArrayAttr(stablehloStride);
DenseI64ArrayAttr baseDilations;
auto windowDilations = rewriter.getDenseI64ArrayAttr(stablehloDilation);
DenseIntElementsAttr pad = DenseIntElementsAttr::get(
RankedTensorType::get(
{static_cast<int64_t>(inputRank), static_cast<int64_t>(2)},
Expand Down Expand Up @@ -242,19 +233,10 @@ LogicalResult ConvertAtenOp<AtenMaxPool2dWithIndicesOp>::matchAndRewrite(
stablehloPadding[stablehloPadding.size() - 2] = padding[1];
stablehloPadding[stablehloPadding.size() - 1] = padding[1];

DenseIntElementsAttr windowDimensions = DenseIntElementsAttr::get(
RankedTensorType::get({static_cast<int64_t>(stablehloKernelSize.size())},
rewriter.getI64Type()),
stablehloKernelSize);
DenseIntElementsAttr windowStrides = DenseIntElementsAttr::get(
RankedTensorType::get({static_cast<int64_t>(stablehloStride.size())},
rewriter.getI64Type()),
stablehloStride);
DenseIntElementsAttr baseDilations;
DenseIntElementsAttr windowDilations = DenseIntElementsAttr::get(
RankedTensorType::get({static_cast<int64_t>(stablehloDilation.size())},
rewriter.getI64Type()),
stablehloDilation);
auto windowDimensions = rewriter.getDenseI64ArrayAttr(stablehloKernelSize);
auto windowStrides = rewriter.getDenseI64ArrayAttr(stablehloStride);
DenseI64ArrayAttr baseDilations;
auto windowDilations = rewriter.getDenseI64ArrayAttr(stablehloDilation);
DenseIntElementsAttr pad = DenseIntElementsAttr::get(
RankedTensorType::get(
{static_cast<int64_t>(inputRank), static_cast<int64_t>(2)},
Expand Down Expand Up @@ -453,20 +435,10 @@ class ConvertAtenAvgPoolOp : public ConvertAtenOp<AtenOpT> {
Value initVal =
createInitialValueForAtenPoolingOp(op, inputElemTy, rewriter);

DenseIntElementsAttr windowDimensions = DenseIntElementsAttr::get(
RankedTensorType::get(
{static_cast<int64_t>(stablehloKernelSize.size())},
rewriter.getI64Type()),
stablehloKernelSize);
DenseIntElementsAttr windowStrides = DenseIntElementsAttr::get(
RankedTensorType::get({static_cast<int64_t>(stablehloStride.size())},
rewriter.getI64Type()),
stablehloStride);
DenseIntElementsAttr baseDilations;
DenseIntElementsAttr windowDilations = DenseIntElementsAttr::get(
RankedTensorType::get({static_cast<int64_t>(stablehloDilation.size())},
rewriter.getI64Type()),
stablehloDilation);
auto windowDimensions = rewriter.getDenseI64ArrayAttr(stablehloKernelSize);
auto windowStrides = rewriter.getDenseI64ArrayAttr(stablehloStride);
DenseI64ArrayAttr baseDilations;
auto windowDilations = rewriter.getDenseI64ArrayAttr(stablehloDilation);
DenseIntElementsAttr pad = DenseIntElementsAttr::get(
RankedTensorType::get(
{static_cast<int64_t>(inputRank), static_cast<int64_t>(2)},
Expand Down Expand Up @@ -508,7 +480,7 @@ class ConvertAtenAvgPoolOp : public ConvertAtenOp<AtenOpT> {
.value();
}
divisor = hlo::promoteType(rewriter, op.getLoc(), divisor, outTy);
DenseIntElementsAttr bcastDimensions;
DenseI64ArrayAttr bcastDimensions;
rewriter.replaceOpWithNewOp<mlir::chlo::BroadcastDivOp>(
op, outTy, reduceWindowSum.getResult(0), divisor, bcastDimensions);
return success();
Expand All @@ -528,7 +500,7 @@ class ConvertAtenAvgPoolOp : public ConvertAtenOp<AtenOpT> {
windowSizeConst = rewriter.create<stablehlo::DynamicBroadcastInDimOp>(
op->getLoc(),
RankedTensorType::get(inputTy.getShape(), outTy.getElementType()),
windowSizeConst, inputShapeTensor, rewriter.getI64TensorAttr({}));
windowSizeConst, inputShapeTensor, rewriter.getDenseI64ArrayAttr({}));

Value zero = createInitialValueForAtenPoolingOp(op, inputElemTy, rewriter);
auto reduceWindowSize = rewriter.create<stablehlo::ReduceWindowOp>(
Expand Down Expand Up @@ -599,19 +571,10 @@ LogicalResult ConvertAtenOp<AtenCumsumOp>::matchAndRewrite(
SmallVector<int64_t> stablehloPadding(inputRank * 2, 0);
stablehloPadding[dim * 2] = inputShape[dim] - 1;

DenseIntElementsAttr windowDimensions = DenseIntElementsAttr::get(
RankedTensorType::get({static_cast<int64_t>(stablehloKernelSize.size())},
rewriter.getI64Type()),
stablehloKernelSize);
DenseIntElementsAttr windowStrides = DenseIntElementsAttr::get(
RankedTensorType::get({static_cast<int64_t>(stablehloStride.size())},
rewriter.getI64Type()),
stablehloStride);
DenseIntElementsAttr baseDilations;
DenseIntElementsAttr windowDilations = DenseIntElementsAttr::get(
RankedTensorType::get({static_cast<int64_t>(stablehloDilation.size())},
rewriter.getI64Type()),
stablehloDilation);
auto windowDimensions = rewriter.getDenseI64ArrayAttr(stablehloKernelSize);
auto windowStrides = rewriter.getDenseI64ArrayAttr(stablehloStride);
DenseI64ArrayAttr baseDilations;
auto windowDilations = rewriter.getDenseI64ArrayAttr(stablehloDilation);
DenseIntElementsAttr pad = DenseIntElementsAttr::get(
RankedTensorType::get(
{static_cast<int64_t>(inputRank), static_cast<int64_t>(2)},
Expand Down
14 changes: 7 additions & 7 deletions lib/Conversion/TorchToStablehlo/Reduction.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,7 @@ getMaxInDim(ConversionPatternRewriter &rewriter, Operation *op, Value &input,
initValue,
initIndex,
},
rewriter.getI64TensorAttr(dim));
rewriter.getDenseI64ArrayAttr(dim));

Block &block = stablehloReduceOp.getBody().emplaceBlock();

Expand Down Expand Up @@ -412,7 +412,7 @@ LogicalResult ConvertAtenReductionOp<AtenSumOp>::matchAndRewrite(

llvm::sort(dims.begin(), dims.end());
auto stablehloReduceOp = rewriter.create<stablehlo::ReduceOp>(
op.getLoc(), input, initValue, rewriter.getI64TensorAttr(dims));
op.getLoc(), input, initValue, rewriter.getDenseI64ArrayAttr(dims));

Block &block = stablehloReduceOp.getBody().emplaceBlock();
auto blockArgumentTy = RankedTensorType::get({}, inputTy.getElementType());
Expand Down Expand Up @@ -473,7 +473,7 @@ LogicalResult ConvertAtenReductionOp<AtenMaxOp>::matchAndRewrite(
return failure();
llvm::sort(dims.begin(), dims.end());
auto stablehloReduceOp = rewriter.create<stablehlo::ReduceOp>(
op.getLoc(), input, initValue, rewriter.getI64TensorAttr(dims));
op.getLoc(), input, initValue, rewriter.getDenseI64ArrayAttr(dims));

Block &block = stablehloReduceOp.getBody().emplaceBlock();
auto blockArgumentTy = RankedTensorType::get({}, inputTy.getElementType());
Expand Down Expand Up @@ -535,7 +535,7 @@ LogicalResult ConvertAtenReductionOp<AtenMinOp>::matchAndRewrite(
return failure();
llvm::sort(dims.begin(), dims.end());
auto stablehloReduceOp = rewriter.create<stablehlo::ReduceOp>(
op.getLoc(), input, initValue, rewriter.getI64TensorAttr(dims));
op.getLoc(), input, initValue, rewriter.getDenseI64ArrayAttr(dims));

Block &block = stablehloReduceOp.getBody().emplaceBlock();
auto blockArgumentTy = RankedTensorType::get({}, inputTy.getElementType());
Expand Down Expand Up @@ -625,7 +625,7 @@ LogicalResult ConvertAtenReductionOp<AtenSumDimIntListOp>::matchAndRewrite(

llvm::sort(dims.begin(), dims.end());
auto stablehloReduceOp = rewriter.create<stablehlo::ReduceOp>(
op.getLoc(), input, initValue, rewriter.getI64TensorAttr(dims));
op.getLoc(), input, initValue, rewriter.getDenseI64ArrayAttr(dims));

Region &region = stablehloReduceOp.getBody();
Block &block = region.emplaceBlock();
Expand Down Expand Up @@ -729,7 +729,7 @@ LogicalResult ConvertAtenReductionOp<AtenFrobeniusNormDimOp>::matchAndRewrite(

auto reduceOp = rewriter.create<stablehlo::ReduceOp>(
op->getLoc(), squareOp.getResult(), initValue,
rewriter.getI64TensorAttr(dims));
rewriter.getDenseI64ArrayAttr(dims));

Region &region = reduceOp.getBody();
Block &block = region.emplaceBlock();
Expand Down Expand Up @@ -848,7 +848,7 @@ LogicalResult ConvertAtenReductionOp<AtenLinalgVectorNormOp>::matchAndRewrite(
ord, nullptr);

auto reduceOp = rewriter.create<stablehlo::ReduceOp>(
op->getLoc(), powValue, initValue, rewriter.getI64TensorAttr(dims));
op->getLoc(), powValue, initValue, rewriter.getDenseI64ArrayAttr(dims));

Region &region = reduceOp.getBody();
Block &block = region.emplaceBlock();
Expand Down
Loading

0 comments on commit 8a17c98

Please sign in to comment.