From a89fa642d7d644139b13f0e47a3f7a566b9b2aa2 Mon Sep 17 00:00:00 2001 From: MaheshRavishankar Date: Thu, 18 Mar 2021 09:48:44 -0700 Subject: [PATCH] Enable MobileBert on GPU for the linalg on tensors path. --- iree/compiler/Conversion/Common/Transforms.cpp | 5 ++--- iree/test/e2e/models/BUILD | 11 +---------- iree/test/e2e/models/CMakeLists.txt | 13 +------------ 3 files changed, 4 insertions(+), 25 deletions(-) diff --git a/iree/compiler/Conversion/Common/Transforms.cpp b/iree/compiler/Conversion/Common/Transforms.cpp index 0c4af0b9d9f0..5d7fa02a1b78 100644 --- a/iree/compiler/Conversion/Common/Transforms.cpp +++ b/iree/compiler/Conversion/Common/Transforms.cpp @@ -317,10 +317,9 @@ LogicalResult defineWorkgroupCountRegion( OpBuilder &builder, FuncOp funcOp, WorkgroupCountRegionBuilder regionBuilder) { IREE::HAL::ExecutableEntryPointOp entryPointOp = getEntryPoint(funcOp); - if (!entryPointOp) + if (!entryPointOp) { return funcOp.emitOpError("unable to find corresponding entry point op"); - if (entryPointOp.getBody()) - return entryPointOp.emitOpError("cannot override workgroup_count_region"); + } Location loc = entryPointOp.getLoc(); OpBuilder::InsertionGuard guard(builder); diff --git a/iree/test/e2e/models/BUILD b/iree/test/e2e/models/BUILD index d897e764de68..cb0403a00c34 100644 --- a/iree/test/e2e/models/BUILD +++ b/iree/test/e2e/models/BUILD @@ -63,18 +63,9 @@ iree_check_single_backend_test_suite( target_backend = "vmla", ) -iree_check_single_backend_test_suite( - name = "check_vulkan-spirv_vulkan", - srcs = CHECK_FRAMEWORK_TESTS, - driver = "vulkan", - target_backend = "vulkan-spirv", -) - iree_check_single_backend_test_suite( name = "check_linalg_on_tensors_vulkan-spirv_vulkan", - srcs = [ - "mobilenetv2_fake_weights.mlir", - ], + srcs = CHECK_FRAMEWORK_TESTS, compiler_flags = [ "-iree-flow-dispatch-linalg-on-tensors", "-iree-codegen-spirv-experimental-linalg-on-tensors", diff --git a/iree/test/e2e/models/CMakeLists.txt b/iree/test/e2e/models/CMakeLists.txt index 0a8aced8b2c8..affb8cfaa23a 100644 --- a/iree/test/e2e/models/CMakeLists.txt +++ b/iree/test/e2e/models/CMakeLists.txt @@ -41,22 +41,11 @@ iree_check_single_backend_test_suite( "vmla" ) -iree_check_single_backend_test_suite( - NAME - check_vulkan-spirv_vulkan - SRCS - "bert_encoder_unrolled_fake_weights.mlir" - "mobilenetv2_fake_weights.mlir" - TARGET_BACKEND - "vulkan-spirv" - DRIVER - "vulkan" -) - iree_check_single_backend_test_suite( NAME check_linalg_on_tensors_vulkan-spirv_vulkan SRCS + "bert_encoder_unrolled_fake_weights.mlir" "mobilenetv2_fake_weights.mlir" TARGET_BACKEND "vulkan-spirv"