[mhlo] Copy BroadcastOp from StableHLO to CHLO.
Part 1 of many:
 - Create a copy of stablehlo.broadcast in CHLO, including changes in the .td and .cc files.
 - Create tests for chlo.broadcast (based on the tests for stablehlo.broadcast).
subhankarshah committed Sep 1, 2022
1 parent d88d0fe commit 4b5e17e
Showing 4 changed files with 151 additions and 0 deletions.
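
For orientation, here is how the new op reads in use, taken directly from the test added in stablehlo/tests/ops_chlo.mlir below:

  %0 = "chlo.broadcast"(%arg0) {broadcast_sizes = dense<[1, 2]> : tensor<2xi64>} : (tensor<3xi32>) -> tensor<1x2x3xi32>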
1 change: 1 addition & 0 deletions stablehlo/dialect/CMakeLists.txt
@@ -67,6 +67,7 @@ add_mlir_dialect_library(ChloOps
  MLIRInferTypeOpInterface
  MLIRIR
  MLIRQuantDialect
  MLIRTensorDialect
)

add_mlir_dialect_library(StablehloRegister
79 changes: 79 additions & 0 deletions stablehlo/dialect/ChloOps.cpp
@@ -16,10 +16,15 @@ limitations under the License.

#include "stablehlo/dialect/ChloOps.h"

#include <complex>

#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/FormatVariadic.h"
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/Complex/IR/Complex.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Traits.h"
#include "mlir/IR/Diagnostics.h"
#include "mlir/IR/PatternMatch.h"
@@ -448,6 +453,80 @@ LogicalResult RankSpecializationClusterOp::verify() {
  return success();
}

//===----------------------------------------------------------------------===//
// BroadcastOp
//===----------------------------------------------------------------------===//

// TODO(b/129012527) These should be expressed as type constraints.
LogicalResult BroadcastOp::verify() {
  auto sizes = broadcast_sizes();
  auto sizesType = sizes.getType();
  auto sizesRank = sizesType.getRank();
  if (sizesRank != 1) {
    return emitOpError(llvm::formatv(
        "broadcast_sizes has rank {0} instead of rank 1", sizesRank));
  }

  return success();
}

LogicalResult BroadcastOp::inferReturnTypeComponents(
    MLIRContext*, Optional<Location> location, ValueShapeRange operands,
    DictionaryAttr attributes, RegionRange regions,
    SmallVectorImpl<ShapedTypeComponents>& inferredReturnShapes) {
  BroadcastOp::Adaptor adaptor(operands, attributes, regions);
  Value operand = adaptor.operand();
  auto operandType = operand.getType().dyn_cast<RankedTensorType>();
  if (!operandType) return failure();

  Type elementTy = operandType.getElementType();
  auto dimensionAttr = adaptor.broadcast_sizes();
  for (int64_t size : dimensionAttr.getValues<int64_t>()) {
    if (size < 0)
      return emitOptionalError(location,
                               "Broadcast with negative dimension size ", size);
  }
  SmallVector<int64_t> shapeValues(dimensionAttr.getValues<int64_t>());
  llvm::append_range(shapeValues, operandType.getShape());

  inferredReturnShapes.emplace_back(shapeValues, elementTy);
  return success();
}

LogicalResult BroadcastOp::reifyReturnTypeShapes(
    OpBuilder& builder, ValueRange operands,
    SmallVectorImpl<Value>& reifiedReturnShapes) {
  BroadcastOp::Adaptor adaptor(operands);
  Value operand = adaptor.operand();

  auto operandType = operand.getType().dyn_cast<RankedTensorType>();
  // Unranked tensors are not supported.
  if (!operandType) return failure();

  Location loc = getLoc();
  SmallVector<Value, 4> shapeValues;

  // Collect the broadcast sizes.
  for (const auto& size : broadcast_sizes()) {
    shapeValues.push_back(
        builder.create<arith::ConstantIndexOp>(loc, size.getZExtValue()));
  }

  // Collect the operand sizes.
  for (auto index : llvm::seq<int64_t>(0, operandType.getRank())) {
    shapeValues.push_back(
        builder.createOrFold<tensor::DimOp>(loc, operand, index));
  }

  reifiedReturnShapes.push_back(builder.create<tensor::FromElementsOp>(
      loc,
      RankedTensorType::get({static_cast<int64_t>(shapeValues.size())},
                            builder.getIndexType()),
      shapeValues));

  return success();
}

//===----------------------------------------------------------------------===//
// TopKOp
//===----------------------------------------------------------------------===//
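To illustrate what these two hooks compute, consider the @broadcast test added below: with an operand of type tensor<3xi32> and broadcast_sizes = dense<[1, 2]>, type inference prepends the sizes to the operand shape and produces tensor<1x2x3xi32>. For an operand with a dynamic dimension, reifyReturnTypeShapes materializes the result shape at runtime; a rough, unverified sketch of the IR it would build for a hypothetical tensor<4x?xf32> operand with broadcast_sizes = dense<[2, 3]>:

  %c2 = arith.constant 2 : index    // broadcast size 0
  %c3 = arith.constant 3 : index    // broadcast size 1
  %c4 = arith.constant 4 : index    // static operand dim 0, folded by createOrFold
  %c1 = arith.constant 1 : index
  %d1 = tensor.dim %operand, %c1 : tensor<4x?xf32>    // dynamic operand dim 1
  %shape = tensor.from_elements %c2, %c3, %c4, %d1 : tensor<4xindex>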
29 changes: 29 additions & 0 deletions stablehlo/dialect/ChloOps.td
@@ -61,6 +61,11 @@ class CHLO_Op<string mnemonic, list<Trait> traits> :
    Op<CHLO_Dialect, mnemonic, traits> {
}

class CHLO_ShapedInterfaceOp<string mnemonic, list<Trait> traits> :
    CHLO_Op<mnemonic, traits # [DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
        ["reifyReturnTypeShapes"]>]> {
}

include "stablehlo/dialect/ChloEnums.td"

class CHLO_NativeOpTrait<string name> : NativeOpTrait<name> {
@@ -945,4 +950,28 @@ def CHLO_DynamicReshapeOp: CHLO_Op<"dynamic_reshape", [NoSideEffect,
  let results = (outs HLO_Tensor:$result);
}

def CHLO_BroadcastOp : CHLO_ShapedInterfaceOp<"broadcast",
    [NoSideEffect, SameOperandsAndResultElementType, InferTensorType]> {
  let summary = "Broadcast a tensor to a higher rank by prepending dimensions";
  let description = [{
    Broadcasts the operand tensor to a higher rank by prepending
    `broadcast_sizes` to the dimensions. The current values of the operand are
    copied into the other dimensions.

    This is a more limited form of broadcasting that corresponds to the XLA
    client Broadcast method. For a more general form of broadcasting, see the
    BroadcastInDimOp.

    See https://www.tensorflow.org/xla/operation_semantics#broadcast.
  }];
  let arguments = (ins
    HLO_Tensor:$operand,
    I64ElementsAttr:$broadcast_sizes
  );

  let results = (outs HLO_Tensor);

  let hasVerifier = 1;
}

#endif // STABLEHLO_DIALECT_CHLO_OPS
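
To make the semantics described above concrete, here is a small sketch (following the XLA Broadcast semantics the description links to, not output generated by this commit):

  // Suppose %arg0 holds [1, 2, 3] : tensor<3xi32>.
  %0 = "chlo.broadcast"(%arg0) {broadcast_sizes = dense<[2]> : tensor<1xi64>} : (tensor<3xi32>) -> tensor<2x3xi32>
  // The result holds [[1, 2, 3], [1, 2, 3]]: one prepended dimension of size 2,
  // with the operand values replicated along it.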
42 changes: 42 additions & 0 deletions stablehlo/tests/ops_chlo.mlir
@@ -125,3 +125,45 @@ func.func @top_k(%arg0 : tensor<16x16xf32>) {
  %0:2 = chlo.top_k(%arg0, k=8) : tensor<16x16xf32> -> (tensor<16x8xf32>, tensor<16x8xi32>)
  return
}

// -----

// CHECK-LABEL: func @broadcast
func.func @broadcast(%arg0: tensor<3xi32>) -> tensor<1x2x3xi32> {
%0 = "chlo.broadcast"(%arg0) {broadcast_sizes = dense<[1, 2]> : tensor<2xi64>} : (tensor<3xi32>) -> tensor<1x2x3xi32>
func.return %0 : tensor<1x2x3xi32>
}

// -----

func.func @broadcast_bad_sizes_rank(%arg0: tensor<3xi32>) -> tensor<1x2x3xi32> {
  // expected-error@+1 {{broadcast_sizes has rank 2 instead of rank 1}}
  %0 = "chlo.broadcast"(%arg0) {broadcast_sizes = dense<[[1, 2]]> : tensor<1x2xi64>} : (tensor<3xi32>) -> tensor<1x2x3xi32>
  func.return %0 : tensor<1x2x3xi32>
}

// -----

func.func @broadcast_bad_result_rank(%arg0: tensor<3xi32>) -> tensor<1x2x3xi32> {
  // expected-error@+1 {{'chlo.broadcast' op inferred type(s) 'tensor<2x3xi32>' are incompatible with return type(s) of operation 'tensor<1x2x3xi32>'}}
  %0 = "chlo.broadcast"(%arg0) {broadcast_sizes = dense<[2]> : tensor<1xi64>} : (tensor<3xi32>) -> tensor<1x2x3xi32>
  func.return %0 : tensor<1x2x3xi32>
}

// -----

func.func @broadcast_bad_first_part_result_shape(%arg0: tensor<3xi32>) -> tensor<1x2x3xi32> {
  // expected-error@+1 {{'chlo.broadcast' op inferred type(s) 'tensor<2x3xi32>' are incompatible with return type(s) of operation 'tensor<1x3xi32>'}}
  %0 = "chlo.broadcast"(%arg0) {broadcast_sizes = dense<[2]> : tensor<1xi64>} : (tensor<3xi32>) -> tensor<1x3xi32>
  func.return %0 : tensor<1x3xi32>
}

// -----

func.func @broadcast_bad_second_part_result_shape(%arg0: tensor<3xi32>) -> tensor<1x2x3xi32> {
  // expected-error@+1 {{'chlo.broadcast' op inferred type(s) 'tensor<2x3xi32>' are incompatible with return type(s) of operation 'tensor<2x1xi32>'}}
  %0 = "chlo.broadcast"(%arg0) {broadcast_sizes = dense<[2]> : tensor<1xi64>} : (tensor<3xi32>) -> tensor<2x1xi32>
  func.return %0 : tensor<2x1xi32>
}

// -----
