Fix CI #55

Closed · wants to merge 1 commit
test/quantization/test_quant_api.py (12 changes: 11 additions, 1 deletion)
@@ -17,13 +17,15 @@
     get_symmetric_quantization_config,
 )
 
+import torchao.quantization.quant_api as quant_api
 from torchao.quantization.quant_api import _replace_with_custom_fn_if_matches_filter
 from torchao.quantization.quant_api import apply_dynamic_quant
 from torchao.quantization.quant_api import (
     Quantizer,
     TwoStepQuantizer,
     Int8DynActInt4WeightGPTQQuantizer,
 )
+from torchao.quantization.utils import is_lm_eval_available
 from pathlib import Path
 from sentencepiece import SentencePieceProcessor
 from model import Transformer
@@ -130,11 +132,19 @@ def test_dynamic_quant_gpu_unified_api_eager_mode_impl(self):
         compiled = m(*example_inputs)
         torch.testing.assert_close(quantized, compiled, atol=0, rtol=0)
 
+    @unittest.skipIf(not is_lm_eval_available(), "Skipping the test when lm_eval is not available")
     def test_gptq(self):
         # should be similar to TorchCompileDynamicQuantizer
+        # from torchao.quantization.quant_api import Int8DynActInt4WeightGPTQQuantizer
+        # Int8DynActInt4WeightGPTQQuantizer = quant_api.Int8DynActInt4WeightGPTQQuantizer
+
         precision = torch.bfloat16
         device = "cpu"
-        checkpoint_path = Path("../gpt-fast/checkpoints/meta-llama/Llama-2-7b-chat-hf/model.pth")
+        try:
+            checkpoint_path = Path("../gpt-fast/checkpoints/meta-llama/Llama-2-7b-chat-hf/model.pth")
+        except:
+            print("didn't find model")
+            return
         model = Transformer.from_name(checkpoint_path.parent.name)
         checkpoint = torch.load(str(checkpoint_path), mmap=True, weights_only=True)
         model.load_state_dict(checkpoint, assign=True)
torchao/quantization/GPTQ.py (14 changes: 3 additions, 11 deletions)
@@ -17,12 +17,12 @@
 import torch.nn.functional as F
 # from model import Transformer  # pyre-ignore[21]
 from torch.utils._pytree import tree_flatten, tree_unflatten
+from .utils import is_lm_eval_available
 
 aten = torch.ops.aten
 
 ## generate.py ##
 
 
 def encode_tokens(tokenizer, string, bos=True, device="cuda"):
 
     tokens = tokenizer.encode(string)
@@ -37,14 +37,7 @@ def model_forward(model, x, input_pos):
 
 ## eval.py ##
 
-try:
-    import lm_eval  # pyre-ignore[21]  # noqa: F401
-
-    lm_eval_available = True
-except:
-    lm_eval_available = False
-
-if lm_eval_available:
+if is_lm_eval_available():
     try:  # lm_eval version 0.4
         from lm_eval.evaluator import evaluate  # pyre-ignore[21]
         from lm_eval.models.huggingface import HFLM as eval_wrapper  # pyre-ignore[21]
@@ -56,7 +49,7 @@ def model_forward(model, x, input_pos):
         get_task_dict = tasks.get_task_dict
         evaluate = evaluator.evaluate
 else:
-    print("lm_eval is not installed, GPTQ may not be usable")
+    raise Exception("lm_eval is not installed, can't import GPTQ")
 
 def setup_cache_padded_seq_input_pos_max_seq_length_for_prefill(
     model: torch.nn.Module,  # pyre-ignore[11]
@@ -93,7 +86,6 @@ def setup_cache_padded_seq_input_pos_max_seq_length_for_prefill(
     input_pos = torch.arange(0, T, device=device)
 
     # no caches in executorch llama2 7b model?
-    print("setting up cache")
     with torch.device(device):
         model.setup_caches(max_batch_size=1, max_seq_length=max_seq_length)
 
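For context, a minimal standalone sketch (not part of this diff) of the guard-and-fallback import pattern that the gated block in GPTQ.py relies on. The lm_eval 0.4 paths mirror the hunk above; the 0.3-style fallback imports (from lm_eval import evaluator, tasks) and the use of ImportError instead of a bare except are assumptions, since that part of the hunk is collapsed.

# Sketch only, assuming the 0.3-style module layout for the fallback branch.
from torchao.quantization.utils import is_lm_eval_available

if is_lm_eval_available():
    try:  # lm_eval version 0.4 layout
        from lm_eval.evaluator import evaluate
        from lm_eval.tasks import get_task_dict
    except ImportError:  # older layout (lm_eval 0.3), assumed
        from lm_eval import evaluator, tasks

        evaluate = evaluator.evaluate
        get_task_dict = tasks.get_task_dict
else:
    raise ImportError("lm_eval is not installed, can't import GPTQ")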
torchao/quantization/quant_api.py (9 changes: 5 additions, 4 deletions)
@@ -208,17 +208,18 @@ def replace_conv2d_1x1(conv):
 )
 
 
-from .GPTQ import lm_eval_available
+from .utils import is_lm_eval_available
 
-if lm_eval_available:
+print("lm_eval_available:", is_lm_eval_available())
+if is_lm_eval_available():
     from .GPTQ import (  # pyre-ignore[21]
         evaluate,
         GenericGPTQRunner,
         get_task_dict,
         InputRecorder,
-        lm_eval,
         MultiInput,
     )
+    print("after import")
 
 
 class GPTQQuantizer(Quantizer):
@@ -633,4 +634,4 @@ def _convert_for_runtime(self, model):
             )
             return model
 else:
-    print("lm_eval not available, skip defining GPTQQuantizer")
+    print("lm_eval not available, skip importing GPTQQuantizer")
torchao/quantization/utils.py (13 changes: 13 additions, 0 deletions)
@@ -13,6 +13,7 @@
     "compute_error",
     "_apply_logging_hook",
     "get_model_size_in_bytes",
+    "is_lm_eval_available",
 ]
 
 
@@ -86,3 +87,15 @@ def get_model_size_in_bytes(model):
     for b in model.buffers():
         s += b.nelement() * b.element_size()
     return s
+
+
+def is_lm_eval_available():
+    lm_eval_available = False
+    try:
+        import lm_eval  # pyre-ignore[21]  # noqa: F401
+
+        lm_eval_available = True
+    except:
+        lm_eval_available = False
+    print("func: is lm eval available:", lm_eval_available)
+    return lm_eval_available
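A minimal sketch (not part of this diff) of how an availability helper like the one added above is meant to be consumed when gating optional tests. The importlib.util.find_spec check, the helper name lm_eval_is_installed, and the test class and method names are hypothetical, shown only to illustrate the skip-when-missing pattern without importing the package.

# Sketch only: gate a test on lm_eval availability without importing it eagerly.
import importlib.util
import unittest


def lm_eval_is_installed() -> bool:
    # find_spec returns None when the package cannot be located; nothing is imported.
    return importlib.util.find_spec("lm_eval") is not None


class OptionalLmEvalTests(unittest.TestCase):
    @unittest.skipIf(not lm_eval_is_installed(), "Skipping the test when lm_eval is not available")
    def test_needs_lm_eval(self):
        import lm_eval  # noqa: F401  # only reached when lm_eval is installed

        self.assertTrue(True)


if __name__ == "__main__":
    unittest.main()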