
Add CPU implementation for torch._int_mm (s8*s8->s32) #121792

Closed
wants to merge 8 commits

Conversation

@Xia-Weiwen (Collaborator) commented Mar 13, 2024

Fixes #121647

Description
Currently, the op torch._int_mm only supports the CUDA device. This PR adds a CPU implementation for it.
Besides the request from the issue, this op may also be useful for planned CPU implementations of LLM.int8() in Bitsandbytes.

The implementation prefers mkldnn (oneDNN) kernels. If mkldnn is not available, a reference implementation with nested for loops is used.
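As an illustration only (not part of this PR's test plan), a minimal usage sketch of the new CPU path; the shapes are arbitrary and chosen just for the example:

import torch

# Two int8 matrices on CPU (the default device).
a = torch.randint(-128, 128, (64, 32), dtype=torch.int8)
b = torch.randint(-128, 128, (32, 16), dtype=torch.int8)

# s8 * s8 -> s32 matrix multiplication, now available on CPU with this PR.
c = torch._int_mm(a, b)
print(c.dtype)   # torch.int32
print(c.shape)   # torch.Size([64, 16])

# For these sizes the int32 result matches a float32 reference exactly.
assert torch.equal(c.to(torch.float32), a.to(torch.float32) @ b.to(torch.float32))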

Test plan
python test/test_linalg.py -k test__int_mm_cpu

cc @jgong5 @mingfeima @XiaobingSuper @sanchitintel @ashokei @jingxu10

@pytorch-bot (bot) commented Mar 13, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/121792

Note: Links to docs will display an error until the docs builds have been completed.

✅ You can merge normally! (1 Unrelated Failure)

As of commit b6c3fb8 with merge base ae983d2:

BROKEN TRUNK - The following job failed but was also present on the merge base:

👉 Rebase onto the `viable/strict` branch to avoid these failures

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@pytorch-bot added the release notes: linalg_frontend label Mar 13, 2024
@github-actions added the module: cpu label Mar 13, 2024
@Xia-Weiwen added the intel label Mar 13, 2024
@jgong5 (Collaborator) left a comment

UT failing..

Review thread on aten/src/ATen/native/LinearAlgebra.cpp (outdated, resolved)
@mingfeima (Collaborator) left a comment

you need to add

#include <ATen/ops/_int_mm_native.h>
#include <ATen/ops/_int_mm_out_native.h>

in LinearAlgebra.cpp to get rid of the clang build errors.

Review thread on aten/src/ATen/native/mkldnn/Matmul.cpp (outdated, resolved)
@@ -5866,6 +5866,34 @@ def _gen_pair(m, k, n):
r"Expected result.size\(0\) to be 17 but got 16",
lambda: torch._int_mm(genf_int(17, 8), genf_int(8, 32), out=genf_int(16, 31).int()))

@onlyCPU
Collaborator

We should expand the existing test case test__int_mm instead of creating a new one for CPU.

Collaborator Author

The test case for CUDA has many restrictions and checks because the CUDA implementation has many limitations regarding shape, CUDA version, etc. However, the CPU implementation does not have those limitations, so it is much easier to keep the CUDA and CPU tests separate. Do you think that's OK? Thanks!

Collaborator

Can you still extend the CUDA case, and add some CPU-only shapes to the test?

Collaborator Author

Hi @lezcano, sorry I did not notice this comment. Do I still need to combine the CPU and CUDA test cases?

ideep::tensor::data_type::s32,
result.strides().vec()},
result.data_ptr());
// Create primitive desc
Collaborator

I thought you would go directly with mkldnn_gemm_s8s8s32: https://oneapi-src.github.io/oneDNN/v0/group__c__api__blas.html#gac1869eab851b572350fb450c50c61626

Which one has better performance, or are they the same?

Collaborator Author

Thanks for the suggestion. I ran benchmarks locally to compare the implementations using the BLAS API and the primitive API. In most cases, the BLAS API showed better performance. However, the BLAS API requires the input buffers to be contiguous. So, the current dispatching rule is: if the input buffers are contiguous, the BLAS API is used; otherwise, the primitive API is used. Do you think that's OK? Thanks.
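For illustration (not from the PR itself), a small sketch of inputs that would exercise each path under this rule; which path is taken is internal to the op, and both should produce the same s32 result:

import torch

a = torch.randint(-128, 128, (64, 32), dtype=torch.int8)

# Contiguous second operand: per the rule above, expected to take the BLAS-API path.
b = torch.randint(-128, 128, (32, 48), dtype=torch.int8)
c_blas = torch._int_mm(a, b)

# Second operand with neither stride equal to 1 (every other column of a wider buffer):
# per the rule above, expected to fall back to the primitive-API path.
b_wide = torch.randint(-128, 128, (32, 96), dtype=torch.int8)
b_strided = b_wide[:, ::2]
c_prim = torch._int_mm(a, b_strided)

# Either way, the result should match a float32 reference for these sizes.
assert torch.equal(c_blas.float(), a.float() @ b.float())
assert torch.equal(c_prim.float(), a.float() @ b_strided.float())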

@Xia-Weiwen (Collaborator, Author)

UT failing..

UT failures are fixed. Thanks.

@Xia-Weiwen (Collaborator, Author)

you need to add

#include <ATen/ops/_int_mm_native.h>
#include <ATen/ops/_int_mm_out_native.h>

in LinearAlgebra.cpp to get rid of the clang build errors.

Thanks. It's added.

@lezcano (Collaborator) left a comment

Minor points only. Feel free to merge after addressing them

@@ -3506,5 +3508,63 @@ Tensor _weight_int8pack_mm_cpu(
return C;
}

Tensor& _int_mm_out_cpu(const Tensor& self, const Tensor& mat2, Tensor& result) {
TORCH_CHECK(self.dim() == 2, __func__, ": Expected self to be of dimension 2 but got ", self.dim());
Collaborator

__func__ is not standard. Better to define a constexpr at the top. Also, these are user-facing names; they should not use internal function names.

Collaborator Author

Thanks. I have defined a string "int_mm_out_cpu" without the leading underscore.

Review thread on aten/src/ATen/native/LinearAlgebra.cpp (outdated, resolved)
Review thread on aten/src/ATen/native/mkldnn/Matmul.cpp (outdated, resolved)
@Xia-Weiwen Xia-Weiwen marked this pull request as ready for review March 18, 2024 01:23
@Xia-Weiwen (Collaborator, Author)

Hi @lezcano, I encountered this CI failure (screenshot not included). Do you have any idea what it is about? Thanks!

@bdhirsh added the triaged label Mar 18, 2024
@Xia-Weiwen (Collaborator, Author)

@pytorchbot merge

@pytorch-bot added the ciflow/trunk label Mar 19, 2024
@pytorchmergebot (Collaborator)

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: Check the merge workflow status here.

@pytorchmergebot (Collaborator)

Merge failed

Reason: 1 job failed: .github/workflows/trunk.yml / macos-12-py3-arm64-mps / test (mps, 1, 1, macos-m1-stable)

Details for Dev Infra team: raised by workflow job.

@lezcano (Collaborator) commented Mar 19, 2024

@pytorchbot merge

@pytorchmergebot (Collaborator)

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: Check the merge workflow status here.

@maktukmak

@Xia-Weiwen, which CPU flags are needed to use this function on Intel CPUs? Also, what will happen when using AMD CPUs?

@Xia-Weiwen (Collaborator, Author)

Hi @maktukmak This function essentially calls the oneDNN BLAS API, so it should support all x86 platforms.
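As a side note (an addition for illustration, not from this thread): one quick way to check whether a given PyTorch build has mkldnn (oneDNN) support at all; per the PR description, the op falls back to the reference nested-loop implementation when mkldnn is unavailable:

import torch

# True if this PyTorch build includes mkldnn (oneDNN) support on CPU.
print(torch.backends.mkldnn.is_available())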

@maktukmak

Someone at Huggingface reported an overflow on AMD Epyc 7R32. What could be the reason?

@dacorvo commented Sep 19, 2024

Example code to reproduce the issue on AMD Epyc 7R32 (typically available on AWS cloud g5 instances).

import pytest  
import torch  

@pytest.mark.parametrize("device", ['cpu', 'cuda'])
@pytest.mark.parametrize("m", [32, 64])
@pytest.mark.parametrize("k", [32, 64])
@pytest.mark.parametrize("n", [32, 64])
@pytest.mark.parametrize("use_transpose_a", [True, False])
@pytest.mark.parametrize("use_transpose_b", [True, False])
@pytest.mark.parametrize("non_contig_type", [0, 1, 2])
def test__int_mm_cpu(device, m, k, n, use_transpose_a, use_transpose_b, non_contig_type):
    
    # non_contig_type:
    # 0: the whole data buffer is contiguous (can be transposed)
    # 1: stride of one dimension is 1, but the whole buffer is not contiguous
    # 2: Neither stride is 1

    def genf_int_float(x, y, use_transpose, non_contig_type):
        if use_transpose:
            x, y = y, x
        if non_contig_type != 0:
            y = y * 2
        x_int8 = torch.randint(-128, 128, (x, y), dtype=torch.int8, device=device)
        x_float = x_int8.to(torch.float32)
        if non_contig_type == 1:
            x_int8 = x_int8[:, : y // 2]
            x_float = x_float[:, : y // 2]
        elif non_contig_type == 2:
            x_int8 = x_int8[:, ::2]
            x_float = x_float[:, ::2]
        if use_transpose:
            return x_int8.t(), x_float.t()
        return x_int8, x_float

    if non_contig_type != 0 and (m == 0 or k == 0):
        return
    a_int8, a_float = genf_int_float(m, k, use_transpose_a, non_contig_type)
    b_int8, b_float = genf_int_float(k, n, use_transpose_b, non_contig_type)
    c_int32 = torch._int_mm(a_int8, b_int8)
    assert torch.equal(c_int32.float(), torch.mm(a_float, b_float))
    c_int32_result = c_int32.new_empty(c_int32.size())
    torch._int_mm(a_int8, b_int8, out=c_int32_result)
    assert torch.equal(c_int32_result.float(), torch.mm(a_float, b_float))

@jgong5 (Collaborator) commented Sep 23, 2024

Someone at Huggingface reported an overflow on AMD Epyc 7R32. What could be the reason?

If so, it seems to be an issue in oneDNN? @vpirogov

@Xia-Weiwen (Collaborator, Author)

Hi @maktukmak @dacorvo Could you give a pointer to the issue you mentioned on HuggingFace?

@dacorvo commented Sep 26, 2024

Here is the link to the issue: huggingface/optimum-quanto#319

@Xia-Weiwen (Collaborator, Author)

@dacorvo Thanks. Could you or @maktukmak open an issue to track this?

@Xia-Weiwen Xia-Weiwen deleted the int_mm_cpu branch September 26, 2024 09:33
@dacorvo commented Sep 26, 2024

#136746

// x:s8 * w:s8 -> y:s32
// both inputs should be 2d
// In most cases, using DNNL blas API is faster but it requires a/b contiguous along one dimentsion
bool a_is_contigous = (mat1.stride(0) == 1 || mat1.stride(1) == 1);
Contributor

typo: contiguous

Labels
ciflow/trunk - Trigger trunk jobs on your pull request
intel - This tag is for PR from Intel
Merged
module: cpu - CPU specific problem (e.g., perf, algorithm)
open source
release notes: linalg_frontend - release notes category
triaged - This issue has been looked at by a team member, and triaged and prioritized into an appropriate module
Development

Successfully merging this pull request may close these issues.

torch._int_mm on CPU
10 participants