From 179195868f3b3a197733a5014cc7e80b7bb88c7e Mon Sep 17 00:00:00 2001
From: Stella Laurenzo
Date: Fri, 4 Aug 2023 09:20:17 -0700
Subject: [PATCH] Add simple pass to turn dense attributes into dense_resource
 attributes. (#14574)

---
Review notes (this area is ignored by `git am`; not part of the commit message):
 * NOTE(review): the template argument lists, `#include <cstring>` and
   `&registry` below were reconstructed from a copy in which `<...>` spans
   had been stripped. The pass-type arguments on the Passes.h context lines
   are a best guess; confirm against the target tree before applying.
 * convertElementsAttr() reads getIntOrFloatBitWidth() before it checks which
   attribute kind it was given; presumably this asserts for element types that
   are neither integer nor float (e.g. index). TODO confirm and guard.
 * The DenseFPElementsAttr branch re-declares `AsmResourceBlob blob;`, which
   shadows the variable declared at the top of the function.

 .../Dialect/Util/Transforms/BUILD.bazel       |   1 +
 .../Dialect/Util/Transforms/CMakeLists.txt    |   1 +
 .../Util/Transforms/ImportResources.cpp       | 206 ++++++++++++++++++
 .../compiler/Dialect/Util/Transforms/Passes.h |   3 +
 .../Dialect/Util/Transforms/Passes.td         |  21 ++
 .../Dialect/Util/Transforms/test/BUILD.bazel  |   1 +
 .../Util/Transforms/test/CMakeLists.txt       |   1 +
 .../Transforms/test/import_resources.mlir     |  89 ++++++++
 8 files changed, 323 insertions(+)
 create mode 100644 compiler/src/iree/compiler/Dialect/Util/Transforms/ImportResources.cpp
 create mode 100644 compiler/src/iree/compiler/Dialect/Util/Transforms/test/import_resources.mlir

diff --git a/compiler/src/iree/compiler/Dialect/Util/Transforms/BUILD.bazel b/compiler/src/iree/compiler/Dialect/Util/Transforms/BUILD.bazel
index c0d0b57145d0..524113ac1255 100644
--- a/compiler/src/iree/compiler/Dialect/Util/Transforms/BUILD.bazel
+++ b/compiler/src/iree/compiler/Dialect/Util/Transforms/BUILD.bazel
@@ -25,6 +25,7 @@ iree_compiler_cc_library(
         "FuseGlobals.cpp",
         "HoistIntoGlobals.cpp",
         "IPO.cpp",
+        "ImportResources.cpp",
         "PassDetail.h",
         "Passes.cpp",
         "Patterns.cpp",
diff --git a/compiler/src/iree/compiler/Dialect/Util/Transforms/CMakeLists.txt b/compiler/src/iree/compiler/Dialect/Util/Transforms/CMakeLists.txt
index 284af5f52bc0..8ae9513f19b4 100644
--- a/compiler/src/iree/compiler/Dialect/Util/Transforms/CMakeLists.txt
+++ b/compiler/src/iree/compiler/Dialect/Util/Transforms/CMakeLists.txt
@@ -28,6 +28,7 @@ iree_cc_library(
     "FuseGlobals.cpp"
     "HoistIntoGlobals.cpp"
     "IPO.cpp"
+    "ImportResources.cpp"
     "PassDetail.h"
     "Passes.cpp"
     "Patterns.cpp"
diff --git a/compiler/src/iree/compiler/Dialect/Util/Transforms/ImportResources.cpp b/compiler/src/iree/compiler/Dialect/Util/Transforms/ImportResources.cpp
new file mode 100644
index 000000000000..625f680852fb
--- /dev/null
+++ b/compiler/src/iree/compiler/Dialect/Util/Transforms/ImportResources.cpp
@@ -0,0 +1,206 @@
+// Copyright 2023 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#include <cstring>
+
+#include "iree/compiler/Dialect/Util/Transforms/PassDetail.h"
+#include "iree/compiler/Dialect/Util/Transforms/Passes.h"
+#include "llvm/ADT/DenseMap.h"
+#include "llvm/Support/Debug.h"
+#include "mlir/IR/BuiltinAttributes.h"
+#include "mlir/IR/BuiltinDialect.h"
+#include "mlir/Pass/Pass.h"
+
+#define DEBUG_TYPE "iree-util-import-resources"
+
+namespace mlir::iree_compiler::IREE::Util {
+
+namespace {
+
+// TODO: Just use the DenseResourceElementsAttr::get()
+// builder once https://reviews.llvm.org/D157064 lands.
+class DenseBlobResourceElementsAttr : public DenseResourceElementsAttr {
+public:
+  using DenseResourceElementsAttr::get;
+};
+
+template <typename ElementType, unsigned numBits = sizeof(ElementType) * 8>
+static void copyIntAttrIntoBlob(AsmResourceBlob &blob,
+                                DenseIntElementsAttr attr) {
+  ArrayRef<ElementType> data = blob.getDataAs<ElementType>();
+  MutableArrayRef<ElementType> rwData = MutableArrayRef<ElementType>(
+      const_cast<ElementType *>(data.data()), data.size());
+  ArrayRef<char> rawSrcData = attr.getRawData();
+  if (rawSrcData.size() == blob.getData().size()) {
+    // Memcpy.
+    std::memcpy(rwData.data(), rawSrcData.data(), rawSrcData.size());
+  } else {
+    // Slow.
+    size_t index = 0;
+    for (APInt value : attr.getValues<APInt>()) {
+      rwData[index++] = value.extractBitsAsZExtValue(numBits, 0);
+    }
+  }
+}
+
+template <typename ElementType, unsigned numBits = sizeof(ElementType) * 8>
+static void copyFPAttrIntoBlob(AsmResourceBlob &blob,
+                               DenseFPElementsAttr attr) {
+  ArrayRef<ElementType> data = blob.getDataAs<ElementType>();
+  MutableArrayRef<ElementType> rwData = MutableArrayRef<ElementType>(
+      const_cast<ElementType *>(data.data()), data.size());
+  ArrayRef<char> rawSrcData = attr.getRawData();
+  if (rawSrcData.size() == blob.getData().size()) {
+    // Memcpy.
+    std::memcpy(rwData.data(), rawSrcData.data(), rawSrcData.size());
+  } else {
+    // Slow.
+    size_t index = 0;
+    for (APFloat value : attr.getValues<APFloat>()) {
+      rwData[index++] =
+          value.bitcastToAPInt().extractBitsAsZExtValue(numBits, 0);
+    }
+  }
+}
+
+class ImportResourcesPass : public ImportResourcesBase<ImportResourcesPass> {
+public:
+  void getDependentDialects(DialectRegistry &registry) const override {
+    registry.insert<BuiltinDialect>();
+  }
+
+  void runOnOperation() override {
+    llvm::DenseMap<Attribute, Attribute> replacements;
+
+    getOperation()->walk([&](Operation *op) {
+      bool updated = false;
+      SmallVector<NamedAttribute> attrs(op->getAttrs());
+      for (auto &attr : attrs) {
+        if (auto elements = llvm::dyn_cast<ElementsAttr>(attr.getValue())) {
+          // Already seen?
+          auto it = replacements.find(elements);
+          if (it != replacements.end()) {
+            LLVM_DEBUG(llvm::dbgs()
+                       << ":: Replacing already encountered attr of "
+                       << elements.getType() << "\n");
+            attr.setValue(it->second);
+            updated = true;
+            continue;
+          }
+
+          // Convert.
+          if (shouldConvertElements(elements)) {
+            LLVM_DEBUG(llvm::dbgs() << ":: Converting elements attr of "
+                                    << elements.getType() << "\n");
+            if (auto replacement = convertElementsAttr(elements)) {
+              attr.setValue(replacement);
+              replacements[elements] = replacement;
+              updated = true;
+            } else {
+              LLVM_DEBUG(llvm::dbgs() << " Failed to convert\n");
+            }
+          }
+        }
+      }
+      if (updated)
+        op->setAttrs(attrs);
+    });
+    LLVM_DEBUG(llvm::dbgs() << "DONE CONVERTING RESOURCES\n");
+  }
+
+  static bool shouldConvertElements(ElementsAttr attr) {
+    if (llvm::isa<DenseElementsAttr>(attr)) {
+      // DenseElementsAttr encodes arbitrary dimension
+      // splats whereas DenseResourceElementsAttr does not.
+      return !attr.isSplat();
+    }
+
+    return false;
+  }
+
+  static ElementsAttr convertElementsAttr(ElementsAttr elementsAttr) {
+    auto st = llvm::cast<ShapedType>(elementsAttr.getType());
+    auto elementType = st.getElementType();
+    auto numElements = elementsAttr.getNumElements();
+    auto bitWidth = elementType.getIntOrFloatBitWidth();
+    AsmResourceBlob blob;
+    if (auto attr = llvm::dyn_cast<DenseIntElementsAttr>(elementsAttr)) {
+      switch (bitWidth) {
+      case 1:
+        blob = HeapAsmResourceBlob::allocate(numElements, /*align=*/64,
+                                             /*dataIsMutable=*/true);
+        copyIntAttrIntoBlob<uint8_t, /*numBits=*/1>(blob, attr);
+        return DenseBlobResourceElementsAttr::get(st, "dense_elements_i1",
+                                                  std::move(blob));
+      case 8:
+        blob = HeapAsmResourceBlob::allocate(numElements, /*align=*/64,
+                                             /*dataIsMutable=*/true);
+        copyIntAttrIntoBlob<uint8_t>(blob, attr);
+        return DenseBlobResourceElementsAttr::get(st, "dense_elements_i8",
+                                                  std::move(blob));
+      case 16:
+        blob = HeapAsmResourceBlob::allocate(2 * numElements, /*align=*/64,
+                                             /*dataIsMutable=*/true);
+        copyIntAttrIntoBlob<uint16_t>(blob, attr);
+        return DenseBlobResourceElementsAttr::get(st, "dense_elements_i16",
+                                                  std::move(blob));
+      case 32:
+        blob = HeapAsmResourceBlob::allocate(4 * numElements, /*align=*/64,
+                                             /*dataIsMutable=*/true);
+        copyIntAttrIntoBlob<uint32_t>(blob, attr);
+        return DenseBlobResourceElementsAttr::get(st, "dense_elements_i32",
+                                                  std::move(blob));
+      case 64:
+        blob = HeapAsmResourceBlob::allocate(8 * numElements, /*align=*/64,
+                                             /*dataIsMutable=*/true);
+        copyIntAttrIntoBlob<uint64_t>(blob, attr);
+        return DenseBlobResourceElementsAttr::get(st, "dense_elements_i64",
+                                                  std::move(blob));
+      default:
+        return {};
+      }
+    } else if (auto attr = llvm::dyn_cast<DenseFPElementsAttr>(elementsAttr)) {
+      AsmResourceBlob blob;
+      switch (bitWidth) {
+      case 8:
+        blob = HeapAsmResourceBlob::allocate(numElements, /*align=*/64,
+                                             /*dataIsMutable=*/true);
+        copyFPAttrIntoBlob<uint8_t>(blob, attr);
+        return DenseBlobResourceElementsAttr::get(st, "dense_elements_f8",
+                                                  std::move(blob));
+      case 16:
+        blob = HeapAsmResourceBlob::allocate(2 * numElements, /*align=*/64,
+                                             /*dataIsMutable=*/true);
+        copyFPAttrIntoBlob<uint16_t>(blob, attr);
+        return DenseBlobResourceElementsAttr::get(st, "dense_elements_f16",
+                                                  std::move(blob));
+      case 32:
+        blob = HeapAsmResourceBlob::allocate(4 * numElements, /*align=*/64,
+                                             /*dataIsMutable=*/true);
+        copyFPAttrIntoBlob<uint32_t>(blob, attr);
+        return DenseBlobResourceElementsAttr::get(st, "dense_elements_f32",
+                                                  std::move(blob));
+      case 64:
+        blob = HeapAsmResourceBlob::allocate(8 * numElements, /*align=*/64,
+                                             /*dataIsMutable=*/true);
+        copyFPAttrIntoBlob<uint64_t>(blob, attr);
+        return DenseBlobResourceElementsAttr::get(st, "dense_elements_f64",
+                                                  std::move(blob));
+      default:
+        return {};
+      }
+    }
+    return {};
+  }
+};
+
+} // namespace
+
+std::unique_ptr<OperationPass<void>> createImportResourcesPass() {
+  return std::make_unique<ImportResourcesPass>();
+}
+
+} // namespace mlir::iree_compiler::IREE::Util
diff --git a/compiler/src/iree/compiler/Dialect/Util/Transforms/Passes.h b/compiler/src/iree/compiler/Dialect/Util/Transforms/Passes.h
index 39c219f2abd8..abf1c190467b 100644
--- a/compiler/src/iree/compiler/Dialect/Util/Transforms/Passes.h
+++ b/compiler/src/iree/compiler/Dialect/Util/Transforms/Passes.h
@@ -29,6 +29,9 @@ std::unique_ptr<OperationPass<mlir::ModuleOp>> createPropagateSubrangesPass();
 std::unique_ptr<OperationPass<void>> createSimplifyGlobalAccessesPass();
 std::unique_ptr<OperationPass<void>> createStripDebugOpsPass();
 
+// Resource Management.
+std::unique_ptr<OperationPass<void>> createImportResourcesPass();
+
 // Type conversion.
 std::unique_ptr<OperationPass<mlir::ModuleOp>> createDemoteI64ToI32Pass();
 std::unique_ptr<OperationPass<mlir::ModuleOp>> createDemoteF32ToF16Pass();
diff --git a/compiler/src/iree/compiler/Dialect/Util/Transforms/Passes.td b/compiler/src/iree/compiler/Dialect/Util/Transforms/Passes.td
index 0da39c070b50..873b1ab07b1c 100644
--- a/compiler/src/iree/compiler/Dialect/Util/Transforms/Passes.td
+++ b/compiler/src/iree/compiler/Dialect/Util/Transforms/Passes.td
@@ -101,6 +101,27 @@ def SimplifyGlobalAccesses :
   }];
 }
 
+//===----------------------------------------------------------------------===//
+// Resource Management
+//===----------------------------------------------------------------------===//
+
+def ImportResources : Pass<"iree-util-import-resources", ""> {
+  let summary = "Imports IR with arbitrary large-data into resources that IREE can manage efficiently";
+  let description = [{
+    MLIR has many interesting ways to store large constants, most of which
+    derive from *ElementsAttr. Given the uniquing/inline behavior, this exacts
+    very large runtime and memory overhead costs.
+
+    This is a temporary pass to convert a majority of the legacy
+    DenseElementsAttr attributes to DenseResourceElementsAttr. Ideally this
+    is done at the source (frontend), but this pass is provided to aid
+    transition and testing by doing a manual conversion with iree-opt.
+  }];
+  let constructor = [{
+    mlir::iree_compiler::IREE::Util::createImportResourcesPass()
+  }];
+}
+
 //===----------------------------------------------------------------------===//
 // Type Conversion
 //===----------------------------------------------------------------------===//
diff --git a/compiler/src/iree/compiler/Dialect/Util/Transforms/test/BUILD.bazel b/compiler/src/iree/compiler/Dialect/Util/Transforms/test/BUILD.bazel
index db4655d0396a..c473c6460fee 100644
--- a/compiler/src/iree/compiler/Dialect/Util/Transforms/test/BUILD.bazel
+++ b/compiler/src/iree/compiler/Dialect/Util/Transforms/test/BUILD.bazel
@@ -27,6 +27,7 @@ iree_lit_test_suite(
             "fuse_globals.mlir",
             "hoist_into_globals.mlir",
             "hoist_into_globals_linalg.mlir",
+            "import_resources.mlir",
             "ipo.mlir",
             "promote_bf16_to_f32.mlir",
             "promote_f16_to_f32.mlir",
diff --git a/compiler/src/iree/compiler/Dialect/Util/Transforms/test/CMakeLists.txt b/compiler/src/iree/compiler/Dialect/Util/Transforms/test/CMakeLists.txt
index f855c45958f8..dfc917f9237c 100644
--- a/compiler/src/iree/compiler/Dialect/Util/Transforms/test/CMakeLists.txt
+++ b/compiler/src/iree/compiler/Dialect/Util/Transforms/test/CMakeLists.txt
@@ -25,6 +25,7 @@ iree_lit_test_suite(
     "fuse_globals.mlir"
     "hoist_into_globals.mlir"
     "hoist_into_globals_linalg.mlir"
+    "import_resources.mlir"
     "ipo.mlir"
     "promote_bf16_to_f32.mlir"
     "promote_f16_to_f32.mlir"
diff --git a/compiler/src/iree/compiler/Dialect/Util/Transforms/test/import_resources.mlir b/compiler/src/iree/compiler/Dialect/Util/Transforms/test/import_resources.mlir
new file mode 100644
index 000000000000..e8b3b503eda0
--- /dev/null
+++ b/compiler/src/iree/compiler/Dialect/Util/Transforms/test/import_resources.mlir
@@ -0,0 +1,89 @@
+// RUN: iree-opt --split-input-file --iree-util-import-resources %s | FileCheck %s
+
+// CHECK-LABEL: func.func @constant_splat_i64
+func.func @constant_splat_i64() -> tensor<4xi64> {
+  // Splats should not convert.
+  // CHECK-NEXT: constant dense<123>
+  %c123 = arith.constant dense<123> : tensor<4xi64>
+  return %c123 : tensor<4xi64>
+}
+
+// -----
+// CHECK-LABEL: func.func @dense_i1
+func.func @dense_i1() -> tensor<4xi1> {
+  // CHECK: dense_resource
+  %c123 = arith.constant dense<[true, false, false, true]> : tensor<4xi1>
+  return %c123 : tensor<4xi1>
+}
+
+// CHECK: dense_elements_i1: "0x4000000001000001"
+
+// -----
+// CHECK-LABEL: func.func @dense_i8
+func.func @dense_i8() -> tensor<4xi8> {
+  // CHECK: dense_resource
+  %c123 = arith.constant dense<[1, 2, 3, 127]> : tensor<4xi8>
+  return %c123 : tensor<4xi8>
+}
+
+// CHECK: dense_elements_i8: "0x400000000102037F"
+
+// -----
+// CHECK-LABEL: func.func @dense_i16
+func.func @dense_i16() -> tensor<4xi16> {
+  // CHECK: dense_resource
+  %c123 = arith.constant dense<[1, 2, 3, 127]> : tensor<4xi16>
+  return %c123 : tensor<4xi16>
+}
+
+// CHECK: dense_elements_i16: "0x400000000100020003007F00"
+
+// -----
+// CHECK-LABEL: func.func @dense_i32
+func.func @dense_i32() -> tensor<4xi32> {
+  // CHECK: dense_resource
+  %c123 = arith.constant dense<[1, 2, 3, 127]> : tensor<4xi32>
+  return %c123 : tensor<4xi32>
+}
+
+// CHECK: dense_elements_i32: "0x400000000100000002000000030000007F000000"
+
+// -----
+// CHECK-LABEL: func.func @dense_i64
+func.func @dense_i64() -> tensor<4xi64> {
+  // CHECK: dense_resource
+  %c123 = arith.constant dense<[1, 2, 3, 127]> : tensor<4xi64>
+  return %c123 : tensor<4xi64>
+}
+
+// CHECK: dense_elements_i64: "0x400000000100000000000000020000000000000003000000000000007F00000000000000"
+
+// -----
+// CHECK-LABEL: func.func @dense_f16
+func.func @dense_f16() -> tensor<4xf16> {
+  // CHECK: dense_resource
+  %c123 = arith.constant dense<[1.1, 2.2, 3.3, 0.0]> : tensor<4xf16>
+  return %c123 : tensor<4xf16>
+}
+
+// CHECK: dense_elements_f16: "0x40000000663C66409A420000"
+
+// -----
+// CHECK-LABEL: func.func @dense_f32
+func.func @dense_f32() -> tensor<4xf32> {
+  // CHECK: dense_resource
+  %c123 = arith.constant dense<[1.1, 2.2, 3.3, 0.0]> : tensor<4xf32>
+  return %c123 : tensor<4xf32>
+}
+
+// CHECK: dense_elements_f32: "0x40000000CDCC8C3FCDCC0C403333534000000000"
+
+// -----
+// CHECK-LABEL: func.func @dense_f64
+func.func @dense_f64() -> tensor<4xf64> {
+  // CHECK: dense_resource
+  %c123 = arith.constant dense<[1.1, 2.2, 3.3, 0.0]> : tensor<4xf64>
+  return %c123 : tensor<4xf64>
+}
+
+// CHECK: dense_elements_f64: "0x400000009A9999999999F13F9A999999999901406666666666660A400000000000000000"