Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

AWQ: Separate the AWQ kernels to separate repository #279

Merged
merged 6 commits into from
Dec 28, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 1 addition & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ AutoAWQ is an easy-to-use package for 4-bit quantized models. AutoAWQ speeds up

- Your GPU(s) must be of Compute Capability 7.5. Turing and later architectures are supported.
- Your CUDA version must be CUDA 11.8 or later.
- Requires installing [AutoAWQ kernels](https://github.com/casper-hansen/AutoAWQ_kernels).

### Install from PyPi

Expand All @@ -49,8 +50,6 @@ pip install https://github.com/casper-hansen/AutoAWQ/releases/download/v0.1.6/au

### Build from source

Build time can take 10-20 minutes. Download your model while you install AutoAWQ.

```
git clone https://github.com/casper-hansen/AutoAWQ
cd AutoAWQ
Expand Down
4 changes: 2 additions & 2 deletions awq/modules/fused/attn.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@


try:
import ft_inference_engine
import awq_ft_ext
FT_INSTALLED = True
except:
FT_INSTALLED = False
Expand Down Expand Up @@ -214,7 +214,7 @@ def forward(self, hidden_states:torch.Tensor, attention_mask=None, *args, **kwar
xv = xv.view((bsz,) + self.attention_shapes["single_xv_view"])

alibi_slopes = self.alibi.slopes if self.alibi is not None else None
attention_weight = ft_inference_engine.single_query_attention(
attention_weight = awq_ft_ext.single_query_attention(
xq, # query
xk, # key
xv, # value
Expand Down
6 changes: 3 additions & 3 deletions awq/modules/fused/mlp.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import torch.nn as nn
import awq_inference_engine
import awq_ext
import torch.nn.functional as F
from awq.modules.linear import WQLinear_GEMM, WQLinear_GEMV

Expand Down Expand Up @@ -28,10 +28,10 @@ def __init__(
self.down_proj = down_proj

if isinstance(down_proj, WQLinear_GEMV):
self.linear = awq_inference_engine.gemv_forward_cuda
self.linear = awq_ext.gemv_forward_cuda
self.group_size = down_proj.group_size
else:
self.linear = awq_inference_engine.gemm_forward_cuda
self.linear = awq_ext.gemm_forward_cuda
self.group_size = 8

self.activation = activation
Expand Down
4 changes: 2 additions & 2 deletions awq/modules/fused/norm.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import torch
from torch import nn
import awq_inference_engine
import awq_ext

class FasterTransformerRMSNorm(nn.Module):
def __init__(self, weight, eps=1e-6):
Expand All @@ -10,5 +10,5 @@ def __init__(self, weight, eps=1e-6):

def forward(self, x):
output = torch.empty_like(x)
awq_inference_engine.layernorm_forward_cuda(x, self.weight, output, self.variance_epsilon)
awq_ext.layernorm_forward_cuda(x, self.weight, output, self.variance_epsilon)
return output
8 changes: 4 additions & 4 deletions awq/modules/linear.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import math
import torch
import torch.nn as nn
import awq_inference_engine # with CUDA kernels
import awq_ext # with CUDA kernels


def make_divisible(c, divisor):
Expand Down Expand Up @@ -102,7 +102,7 @@ def forward(self, x):
if input_dtype != torch.float16:
x = x.half()

out = awq_inference_engine.gemm_forward_cuda(x.reshape(-1, x.shape[-1]), self.qweight, self.scales, self.qzeros, 8)
out = awq_ext.gemm_forward_cuda(x.reshape(-1, x.shape[-1]), self.qweight, self.scales, self.qzeros, 8)

if input_dtype != torch.float16:
out = out.to(dtype=input_dtype)
Expand Down Expand Up @@ -210,9 +210,9 @@ def forward(self, x):
inputs = inputs.half()

if inputs.shape[0] > 8:
out = awq_inference_engine.gemmv2_forward_cuda(inputs, self.qweight, self.scales, self.qzeros, self.group_size, self.split_k_iters)
out = awq_ext.gemmv2_forward_cuda(inputs, self.qweight, self.scales, self.qzeros, self.group_size, self.split_k_iters)
else:
out = awq_inference_engine.gemv_forward_cuda(inputs, self.qweight, self.scales, self.qzeros, self.group_size)
out = awq_ext.gemv_forward_cuda(inputs, self.qweight, self.scales, self.qzeros, self.group_size)

if input_dtype != torch.float16:
out = out.to(dtype=input_dtype)
Expand Down
257 changes: 0 additions & 257 deletions awq_cuda/attention/cuda_bf16_fallbacks.cuh

This file was deleted.

23 changes: 0 additions & 23 deletions awq_cuda/attention/cuda_bf16_wrapper.h

This file was deleted.

Loading