
Request for making Matmul/Zero-Fill microkernel aware about offset and strides #1478

Open
Abhishek-Varma opened this issue May 10, 2024 · 4 comments


@Abhishek-Varma
Contributor

Currently, the Matmul ukernel is offset-aware.

Adding offset support was trivial, but now we also want to make the ukernel aware of the strides.

The end-to-end IR log from iree-amd-aie that motivates this requirement is here: Ukernel_bf16_IR_log.

In the IR above, you can see that IREE now generates invocations like:

func.call @zero_bf16(%base_buffer, %c0, %c16384, %c16384) : (memref<bf16, 2 : i32>, index, index, index) -> ()

Previously this worked because the invocation was func.call @zero_bf16(%base_buffer, %c0) : (memref<bf16, 2 : i32>, index) -> (), which matches the signature of the current ukernel (linked above).
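For illustration only, the offset-only convention can be sketched in C. Everything here is an assumption and not the actual mlir-aie kernel: the hypothetical name zero_bf16_offset_only, passing the 0-d memref as a bare pointer, index lowered to int32_t, bf16 stored as raw uint16_t bits, and a tile size of 16*16*4*4 elements. The sketch shows the limitation. Base plus offset alone can only describe a tile that is one contiguous run starting at base + offset.

```c
#include <stdint.h>

/* Hypothetical tile size: 16 x 16 blocks of 4 x 4 bf16 elements. */
#define TILE_ELEMS (16 * 16 * 4 * 4)

/* bf16 values stored as raw 16-bit patterns; 0x0000 is +0.0. */
typedef uint16_t bf16_bits;

/* Hypothetical offset-aware zero fill: it can only assume the whole tile
 * is one contiguous run of TILE_ELEMS elements starting at base + offset. */
void zero_bf16_offset_only(bf16_bits *base, int32_t offset) {
  bf16_bits *tile = base + offset;
  for (int32_t i = 0; i < TILE_ELEMS; ++i) {
    tile[i] = 0;
  }
}
```

A kernel like this has no way to skip over elements that lie between parts of the tile, which is exactly what the extra stride operands in the new invocation are for.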

@erwei-xilinx
Collaborator

Copying @jackl-xilinx, @jgmelber, @denolf and @stephenneuendorffer.

@erwei-xilinx
Collaborator

Hi @Abhishek-Varma, what do the two 16384 represent in this zero fill example?

@Abhishek-Varma
Contributor Author

Hi @erwei-xilinx

They are the strides of the two outer dimensions, obtained from the strided metadata of the subview of the memref:

%subview_11 = memref.subview %alloc_3[0, 0, %arg3, %arg2, 0, 0] [1, 1, 16, 16, 4, 4] [1, 1, 1, 1, 1, 1] : memref<1x1x32x32x4x4xbf16, 2 : i32> to memref<1x1x16x16x4x4xbf16, strided<[16384, 16384, 512, 16, 4, 1], offset: ?>, 2 : i32>
%base_buffer, %offset, %sizes:6, %strides:6 = memref.extract_strided_metadata %subview_11 : memref<1x1x16x16x4x4xbf16, strided<[16384, 16384, 512, 16, 4, 1], offset: ?>, 2 : i32> -> memref<bf16, 2 : i32>, index, index, index, index, index, index, index, index, index, index, index, index, index
func.call @zero_bf16(%base_buffer, %offset, %strides#0, %strides#1) : (memref<bf16, 2 : i32>, index, index, index) -> ()
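The 16384 values can be checked by hand. %alloc_3 has the identity (row-major) layout on shape 1x1x32x32x4x4, so its strides are 32*32*4*4 = 16384 for both leading dimensions, followed by 512, 16, 4 and 1. A subview with unit steps keeps the parent's strides, and its dynamic offset is the sum of each start index times the parent stride. The C sketch below computes both. The helper names row_major_strides and subview_offset are hypothetical.

```c
#include <stddef.h>
#include <stdint.h>

/* Row-major strides: the innermost dimension has stride 1 and each outer
 * stride is the product of all sizes inside it. */
void row_major_strides(const int64_t *shape, int64_t *strides, size_t rank) {
  int64_t running = 1;
  for (size_t d = rank; d-- > 0;) {
    strides[d] = running;
    running *= shape[d];
  }
}

/* Offset of a unit-step subview of an identity-layout buffer:
 * sum over dimensions of start index times parent stride. */
int64_t subview_offset(const int64_t *starts, const int64_t *strides,
                       size_t rank) {
  int64_t offset = 0;
  for (size_t d = 0; d < rank; ++d) {
    offset += starts[d] * strides[d];
  }
  return offset;
}
```

For example, with %arg3 = 16 and %arg2 = 16 (values chosen only for illustration), the offset would be 16 * 512 + 16 * 16 = 8448.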

@MaheshRavishankar
Contributor

Adding some more information about what the microkernel needs to support. It is useful to look at the IR for one of the ukernel ops before and after lowering to function calls (taken from here).

Before lowering to function call

        %subview_11 = memref.subview %alloc_3[0, 0, %arg3, %arg2, 0, 0] [1, 1, 16, 16, 4, 4] [1, 1, 1, 1, 1, 1] : memref<1x1x32x32x4x4xbf16, 2 : i32> to memref<1x1x16x16x4x4xbf16, strided<[16384, 16384, 512, 16, 4, 1], offset: ?>, 2 : i32>
        iree_codegen.ukernel.generic "zero_bf16" outs(%subview_11 : memref<1x1x16x16x4x4xbf16, strided<[16384, 16384, 512, 16, 4, 1], offset: ?>, 2 : i32>) fn_def_attrs {link_with = "mm.o"} strided_outer_dims(2)

Notice that in the outs subview only the inner two dimensions are contiguous, which is evident from the subview itself. As a result, the microkernel actually needs 4 of the outer strides, so the value of strided_outer_dims should be 4 instead of 2. All four of these strides must be accounted for if the microkernel is to access the data correctly. So the corrected version should be:

        %subview_11 = memref.subview %alloc_3[0, 0, %arg3, %arg2, 0, 0] [1, 1, 16, 16, 4, 4] [1, 1, 1, 1, 1, 1] : memref<1x1x32x32x4x4xbf16, 2 : i32> to memref<1x1x16x16x4x4xbf16, strided<[16384, 16384, 512, 16, 4, 1], offset: ?>, 2 : i32>
        iree_codegen.ukernel.generic "zero_bf16" outs(%subview_11 : memref<1x1x16x16x4x4xbf16, strided<[16384, 16384, 512, 16, 4, 1], offset: ?>, 2 : i32>) fn_def_attrs {link_with = "mm.o"} strided_outer_dims(4)

After lowering to a function call, the current code (with strided_outer_dims(2)) is:

        %base_buffer, %offset, %sizes:6, %strides:6 = memref.extract_strided_metadata %subview_11 : memref<1x1x16x16x4x4xbf16, strided<[16384, 16384, 512, 16, 4, 1], offset: ?>, 2 : i32> -> memref<bf16, 2 : i32>, index, index, index, index, index, index, index, index, index, index, index, index, index
        func.call @zero_bf16(%base_buffer, %offset, %strides#0, %strides#1) : (memref<bf16, 2 : i32>, index, index, index) -> ()
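Here is why two strides are not enough. %strides#0 and %strides#1 belong to the size-1 dimensions, so they never change an address, and the kernel cannot learn the stride 512 of dimension 2 from its arguments. The C sketch below compares the true flat index of element [k, l, m, n] of the 16x16x4x4 tile with the index a kernel would compute if it assumed the tile was packed contiguously. The helper names are hypothetical, and the outer size-1 indices are omitted because they are always 0.

```c
#include <stdint.h>

/* True flat index using the strides of %subview_11 for dims 2..5:
 * 512, 16, 4, 1. */
int64_t strided_index(int64_t offset, int64_t k, int64_t l, int64_t m,
                      int64_t n) {
  return offset + k * 512 + l * 16 + m * 4 + n;
}

/* Flat index if the 16x16x4x4 tile were assumed packed contiguously
 * (row-major strides 256, 16, 4, 1). */
int64_t packed_index(int64_t offset, int64_t k, int64_t l, int64_t m,
                     int64_t n) {
  return offset + k * 256 + l * 16 + m * 4 + n;
}
```

The two agree for every element with k = 0 and diverge from k = 1 on. For example, element [1, 0, 0, 0] is at offset + 512, not offset + 256.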

It should instead be at least:

        %base_buffer, %offset, %sizes:6, %strides:6 = memref.extract_strided_metadata %subview_11 : memref<1x1x16x16x4x4xbf16, strided<[16384, 16384, 512, 16, 4, 1], offset: ?>, 2 : i32> -> memref<bf16, 2 : i32>, index, index, index, index, index, index, index, index, index, index, index, index, index
        func.call @zero_bf16(%base_buffer, %offset, %strides#0, %strides#1, %strides#2, %strides#3) : (memref<bf16, 2 : i32>, index, index, index, index, index) -> ()

The microkernel needs to use the offset and strides to access the data correctly. The element outs[i, j, k, l, m, n] of the 6-D tensor has to be accessed as base_buffer[offset + i * strides#0 + j * strides#1 + k * strides#2 + l * strides#3 + m * 4 + n]. Note that this also captures the agreement between the microkernel and the IREE-generated code that the inner two dimensions are contiguous in memory. If that were not the case, we would need to pass strides for those as well.
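Below is a minimal C sketch of a zero fill that follows this access pattern. It rests on assumptions that do not come from the issue: the hypothetical name zero_bf16_strided, passing the memref as a bare pointer, index lowered to int32_t, bf16 stored as raw uint16_t bits, and tile sizes fixed at compile time (only the offset and four strides are passed, as in the call above).

```c
#include <stdint.h>

typedef uint16_t bf16_bits; /* raw bf16 bit pattern; 0x0000 is +0.0 */

/* Tile sizes of the 1x1x16x16x4x4 subview, assumed known at compile time. */
enum { D0 = 1, D1 = 1, D2 = 16, D3 = 16, D4 = 4, D5 = 4 };

/* Hypothetical strided zero fill: outs[i, j, k, l, m, n] lives at
 * base[offset + i*s0 + j*s1 + k*s2 + l*s3 + m*4 + n]. Only the inner
 * 4x4 block is assumed contiguous (row stride 4, element stride 1). */
void zero_bf16_strided(bf16_bits *base, int32_t offset, int32_t s0,
                       int32_t s1, int32_t s2, int32_t s3) {
  for (int32_t i = 0; i < D0; ++i)
    for (int32_t j = 0; j < D1; ++j)
      for (int32_t k = 0; k < D2; ++k)
        for (int32_t l = 0; l < D3; ++l) {
          /* Start of the contiguous 4x4 block [i, j, k, l]. */
          bf16_bits *block = base + offset + i * s0 + j * s1 + k * s2 + l * s3;
          for (int32_t m = 0; m < D4; ++m)
            for (int32_t n = 0; n < D5; ++n)
              block[m * D5 + n] = 0;
        }
}
```

For the subview above with the illustrative values %arg3 = %arg2 = 16 (offset 8448), a call like zero_bf16_strided(buf, 8448, 16384, 16384, 512, 16) on the 16384-element allocation zeros exactly the 4096 tile elements. It leaves the 256-element gaps between consecutive rows of 4x4 blocks untouched.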

cc @newling
