Commit 10acea7

hanhanW authored Jan 15, 2024
1 parent dc37616 commit 10acea7
Showing 5 changed files with 23 additions and 21 deletions.
2 changes: 1 addition & 1 deletion externals/llvm-project
Submodule llvm-project updated 4076 files
16 changes: 9 additions & 7 deletions lib/Conversion/TorchToTosa/TosaLegalizeUtils.cpp
@@ -33,8 +33,9 @@ Value buildRescale(PatternRewriter &rewriter, Operation *op,
rewriter.getI32IntegerAttr(static_cast<int32_t>(input_zp)),
rewriter.getI32IntegerAttr(static_cast<int32_t>(output_zp)),
rewriter.getDenseI32ArrayAttr({multiplier}),
- rewriter.getDenseI32ArrayAttr({shift}), rewriter.getBoolAttr(scale32),
- rewriter.getBoolAttr(double_round), rewriter.getBoolAttr(false));
+ rewriter.getDenseI8ArrayAttr({static_cast<int8_t>(shift)}),
+ rewriter.getBoolAttr(scale32), rewriter.getBoolAttr(double_round),
+ rewriter.getBoolAttr(false));

return rescale_op.getResult();
}
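
The substantive change in this file tracks an upstream TOSA update: tosa.RescaleOp now takes its shift attribute as a DenseI8ArrayAttr rather than a DenseI32ArrayAttr, while the surrounding code still computes shift as an int32_t, hence the explicit static_cast<int8_t> at each call site. A minimal standalone sketch of the narrowing, assuming (as the cast implies but this diff does not assert) that TOSA rescale shifts always fit in an int8_t:

// Sketch only, not torch-mlir code: narrow the int32_t shift produced by
// the multiplier/shift computation to the int8_t element type that
// DenseI8ArrayAttr expects. The range check documents the assumption
// that makes the cast lossless.
#include <cassert>
#include <cstdint>
#include <limits>

int8_t narrowRescaleShift(int32_t shift) {
  assert(shift >= std::numeric_limits<int8_t>::min() &&
         shift <= std::numeric_limits<int8_t>::max() &&
         "TOSA rescale shift must fit in int8_t");
  return static_cast<int8_t>(shift);
}
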
@@ -86,8 +87,9 @@ Value buildRescaleOpConvOutput(PatternRewriter &rewriter, Operation *op,
rewriter, op->getLoc(), output_type, conv_val,
rewriter.getI32IntegerAttr(0), rewriter.getI32IntegerAttr(output_zp),
rewriter.getDenseI32ArrayAttr({multiplier}),
- rewriter.getDenseI32ArrayAttr({shift}), rewriter.getBoolAttr(scale32),
- rewriter.getBoolAttr(true), rewriter.getBoolAttr(false));
+ rewriter.getDenseI8ArrayAttr({static_cast<int8_t>(shift)}),
+ rewriter.getBoolAttr(scale32), rewriter.getBoolAttr(true),
+ rewriter.getBoolAttr(false));

return rescale_op.getResult();

@@ -96,7 +98,7 @@ Value buildRescaleOpConvOutput(PatternRewriter &rewriter, Operation *op,
.dyn_cast<mlir::quant::UniformQuantizedPerAxisType>()) {
// Per-channel quantization
SmallVector<int32_t> multiplier_arr;
- SmallVector<int32_t> shift_arr;
+ SmallVector<int8_t> shift_arr;

SmallVector<double> weight_scale_arr(
weight_per_channel_qtype.getScales().begin(),
@@ -115,14 +117,14 @@ Value buildRescaleOpConvOutput(PatternRewriter &rewriter, Operation *op,
scale_width);

multiplier_arr.push_back(multiplier);
- shift_arr.push_back(shift);
+ shift_arr.push_back(static_cast<int8_t>(shift));
}

auto rescale_op = CreateOpAndInfer<tosa::RescaleOp>(
rewriter, op->getLoc(), output_type, conv_val,
rewriter.getI32IntegerAttr(0), rewriter.getI32IntegerAttr(output_zp),
rewriter.getDenseI32ArrayAttr(multiplier_arr),
- rewriter.getDenseI32ArrayAttr(shift_arr), rewriter.getBoolAttr(scale32),
+ rewriter.getDenseI8ArrayAttr(shift_arr), rewriter.getBoolAttr(scale32),
rewriter.getBoolAttr(true), rewriter.getBoolAttr(true));

return rescale_op.getResult();
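
In the per-channel path the change has to start at the container: Builder::getDenseI8ArrayAttr takes an ArrayRef<int8_t>, and a SmallVector<int32_t> does not convert to that, so shift_arr's element type changes and each value is narrowed as it is collected. A small standalone illustration of that conversion rule, with a hypothetical stand-in for the attribute builder:

// Sketch only: SmallVector<T> converts implicitly to ArrayRef<T> for the
// same T, so an attribute builder taking ArrayRef<int8_t> forces the
// accumulator to hold int8_t from the start.
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/SmallVector.h"

void buildShiftAttr(llvm::ArrayRef<int8_t> shifts); // stands in for getDenseI8ArrayAttr

void collectShifts() {
  llvm::SmallVector<int8_t> shift_arr;
  for (int32_t shift : {30, 31, 32})
    shift_arr.push_back(static_cast<int8_t>(shift)); // narrowed at insertion
  buildShiftAttr(shift_arr); // implicit SmallVector<int8_t> -> ArrayRef<int8_t>
}
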
8 changes: 4 additions & 4 deletions test/Dialect/Torch/canonicalize.mlir
@@ -1,10 +1,10 @@
// RUN: torch-mlir-opt %s -canonicalize | FileCheck %s

// CHECK-LABEL: func.func @torch.aten.__range_length$fold() -> (!torch.int, !torch.int, !torch.int, !torch.int) {
- // CHECK: %[[INT1:.*]] = torch.constant.int 1
- // CHECK: %[[INT2:.*]] = torch.constant.int 2
- // CHECK: %[[INT3:.*]] = torch.constant.int 3
- // CHECK: %[[INTM1:.*]] = torch.constant.int -1
+ // CHECK-DAG: %[[INT1:.*]] = torch.constant.int 1
+ // CHECK-DAG: %[[INT2:.*]] = torch.constant.int 2
+ // CHECK-DAG: %[[INT3:.*]] = torch.constant.int 3
+ // CHECK-DAG: %[[INTM1:.*]] = torch.constant.int -1
// CHECK: %[[NEG_STEP:.*]] = torch.aten.__range_length %[[INT1]], %[[INT3]], %[[INTM1]] : !torch.int, !torch.int, !torch.int -> !torch.int
// CHECK: return %[[INT2]], %[[INT2]], %[[INT1]], %[[NEG_STEP]] : !torch.int, !torch.int, !torch.int, !torch.int
func.func @torch.aten.__range_length$fold() -> (!torch.int, !torch.int, !torch.int, !torch.int) {
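
The test updates in this and the following files are mechanical: CHECK directives must match in source order, so they break when a pass starts materializing constants in a different order, while a run of adjacent CHECK-DAG directives may match in any order. Presumably the LLVM bump perturbed constant ordering; switching the constant lines to CHECK-DAG makes the tests order-insensitive. A tiny hypothetical example:

// Given output that may arrive as either
//   %int1 = torch.constant.int 1
//   %int2 = torch.constant.int 2
// or in the reverse order, the strict form fails when the order flips:
// CHECK: torch.constant.int 1
// CHECK: torch.constant.int 2
// while the order-insensitive form accepts both orderings:
// CHECK-DAG: torch.constant.int 1
// CHECK-DAG: torch.constant.int 2
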
4 changes: 2 additions & 2 deletions test/Dialect/Torch/decompose-complex-ops.mlir
@@ -84,8 +84,8 @@ func.func @torch.aten.adaptive_avg_pool2d$unit_output_size(%arg0: !torch.vtensor

// CHECK-LABEL: func.func @torch.aten.type_as$basic(
// CHECK-SAME: %[[ARG_0:.*]]: !torch.tensor, %[[ARG_1:.*]]: !torch.tensor) -> !torch.tensor {
- // CHECK: %[[FALSE:.*]] = torch.constant.bool false
- // CHECK: %[[NONE:.*]] = torch.constant.none
+ // CHECK-DAG: %[[FALSE:.*]] = torch.constant.bool false
+ // CHECK-DAG: %[[NONE:.*]] = torch.constant.none
// CHECK: %[[DTYPE:.*]] = torch.prim.dtype %[[ARG_1]] : !torch.tensor -> !torch.int
// CHECK: %[[VAR:.*]] = torch.aten.to.dtype %[[ARG_0]], %[[DTYPE]], %[[FALSE]], %[[FALSE]], %[[NONE]] : !torch.tensor, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.tensor
// CHECK: return %[[VAR]] : !torch.tensor
14 changes: 7 additions & 7 deletions test/Dialect/Torch/simplify-shape-calculations.mlir
@@ -105,9 +105,9 @@ func.func @refine_shape_calculate_result$user_allows_type_refinement(%arg0: !tor
// CHECK-LABEL: func.func @fully_unroll_prim_loop$unroll(
// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor,
// CHECK-SAME: %[[ARG1:.*]]: !torch.list<int>) -> !torch.vtensor {
- // CHECK: %[[INT1:.*]] = torch.constant.int 1
- // CHECK: %[[INT2:.*]] = torch.constant.int 2
- // CHECK: %[[INT0:.*]] = torch.constant.int 0
+ // CHECK-DAG: %[[INT1:.*]] = torch.constant.int 1
+ // CHECK-DAG: %[[INT2:.*]] = torch.constant.int 2
+ // CHECK-DAG: %[[INT0:.*]] = torch.constant.int 0
// CHECK: %[[RESULT:.*]] = torch.shape.calculate {
// CHECK: torch.shape.calculate.yield %[[ARG0]] : !torch.vtensor
// CHECK: } shapes {
@@ -375,8 +375,8 @@ func.func @abstractly_interpret_list_ops$miscompile$list_identity(%arg0: !torch.
// missing.
// CHECK-LABEL: func.func @basic_integration(
// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,?],unk>) -> !torch.vtensor {
- // CHECK: %[[INT0:.*]] = torch.constant.int 0
- // CHECK: %[[INT1:.*]] = torch.constant.int 1
+ // CHECK-DAG: %[[INT0:.*]] = torch.constant.int 0
+ // CHECK-DAG: %[[INT1:.*]] = torch.constant.int 1
// CHECK: %[[RESULT:.*]] = torch.shape.calculate {
// CHECK: %[[TANH:.*]] = torch.aten.tanh %[[ARG0]] : !torch.vtensor<[?,?],unk> -> !torch.vtensor<[?,?],unk>
// CHECK: torch.shape.calculate.yield %[[TANH]] : !torch.vtensor<[?,?],unk>
@@ -410,8 +410,8 @@ func.func @basic_integration(%arg0: !torch.vtensor<[?,?],unk>) -> !torch.vtensor
// CHECK-LABEL: func.func @fold_prim_unchecked_cast_op(
// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor,
// CHECK-SAME: %[[VAL_1:.*]]: !torch.vtensor<[?,?],si64>) -> !torch.vtensor {
- // CHECK: %[[VAL_2:.*]] = torch.constant.int 0
- // CHECK: %[[VAL_3:.*]] = torch.constant.int 1
+ // CHECK-DAG: %[[VAL_2:.*]] = torch.constant.int 0
+ // CHECK-DAG: %[[VAL_3:.*]] = torch.constant.int 1
// CHECK: %[[VAL_4:.*]] = torch.shape.calculate {
// CHECK: %[[VAL_5:.*]] = torch.tensor_static_info_cast %[[VAL_0]] : !torch.vtensor to !torch.vtensor<[?,?],unk>
// CHECK: torch.shape.calculate.yield %[[VAL_5]] : !torch.vtensor<[?,?],unk>
