Skip to content

Commit

Permalink
Block gptq
Browse files Browse the repository at this point in the history
  • Loading branch information
Giuseppe5 committed Oct 11, 2024
1 parent 4617f7b commit 58d6f15
Show file tree
Hide file tree
Showing 2 changed files with 98 additions and 26 deletions.
116 changes: 91 additions & 25 deletions src/brevitas_examples/llm/llm_quant/gpxq.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,39 +3,105 @@
# SPDX-License-Identifier: BSD-3-Clause
"""


from copy import deepcopy

from accelerate.utils.operations import send_to_device
import torch
from tqdm import tqdm

from brevitas.graph.gpfq import gpfq_mode
from brevitas.graph.gptq import gptq_mode
from brevitas.graph.gpxq import StopFwdException
from brevitas.utils.python_utils import recurse_getattr

@torch.no_grad()
def block_optimization(model, dataloader, block_name, context_manager_func, context_manager_kwargs):
cache_state = model.config.use_cache
model.config.use_cache = False
blocks = recurse_getattr(model, block_name)
first_block = blocks[0]
cached_args, cached_kwargs = [], []

# Intercept input to first block
def intercept_input(module, args, kwargs):
args = send_to_device(args, 'cpu')
kwargs = send_to_device(kwargs, 'cpu')
cached_args.append(args)
cached_kwargs.append(kwargs)
raise StopFwdException

# Intercept output from block N-1 to set it as input to block N
def intercept_output(module, args, kwargs, output):
if isinstance(output, tuple):
output = output[0]
output = send_to_device(output, 'cpu')
cached_args.append((output,))
raise StopFwdException

# Collect input to first block
hook = first_block.register_forward_pre_hook(intercept_input, with_kwargs=True)
for inps in dataloader:
try:
model(**inps)
except StopFwdException:
pass
hook.remove()

# Iterate through all the blocks
for index, block in enumerate(tqdm(blocks)):
with context_manager_func(block, **context_manager_kwargs) as gpxq:
for _ in tqdm(range(gpxq.num_layers)):
for args, kwargs in zip(cached_args, cached_kwargs):
args = send_to_device(args, 'cuda')
kwargs = send_to_device(kwargs, 'cuda')
block(*args, **kwargs)
gpxq.update()

if index < len(blocks) - 1:
# Once the block is done, we need to update the input to the next block
past_cached_args, past_cached_kwargs = deepcopy(cached_args), deepcopy(cached_kwargs)
cached_args = []
hook = block.register_forward_hook(intercept_output, with_kwargs=True)
for args, kwargs in zip(past_cached_args, past_cached_kwargs):
try:
args = send_to_device(args, 'cuda')
kwargs = send_to_device(kwargs, 'cuda')
block(*args, **kwargs)
except StopFwdException:
pass
hook.remove()
# Restore cache state
model.config.use_cache = cache_state

@torch.no_grad()
def apply_gptq(
model,
dataloader,
act_order=True,
group_of_parallel_layers=None,
use_quant_activations=True,
create_weight_orig=False):
with gptq_mode(model,
act_order=act_order,
group_of_parallel_layers=group_of_parallel_layers,
use_quant_activations=use_quant_activations,
create_weight_orig=create_weight_orig) as gptq:
gptq_model = gptq.model
for _ in tqdm(range(gptq.num_layers)):
for inps in dataloader:
gptq_model(**inps)
gptq.update()
def apply_gptq(model, dataloader, act_order=True, group_of_parallel_layers=None, block_name=None):
if block_name is not None:
context_manager_kwargs = {'act_order': act_order, 'group_of_parallel_layers': group_of_parallel_layers, 'create_weight_orig': False, 'use_quant_activations': False}
block_optimization(model, dataloader, block_name, gptq_mode, context_manager_kwargs)
else:
with gptq_mode(model,
use_quant_activations=False,
group_of_parallel_layers=group_of_parallel_layers,
act_order=act_order,
create_weight_orig=False) as gptq:
gptq_model = gptq.model
for _ in tqdm(range(gptq.num_layers)):
for inps in dataloader:
gptq_model(**inps)
gptq.update()


@torch.no_grad()
def apply_gpfq(model, dataloader, act_order=True, group_of_parallel_layers=None):
with gpfq_mode(model, act_order=act_order,
group_of_parallel_layers=group_of_parallel_layers) as gpfq:
gpfq_model = gpfq.model
for _ in tqdm(range(gpfq.num_layers)):
for inps in dataloader:
gpfq_model(**inps)
gpfq.update()
def apply_gpfq(model, dataloader, act_order=True, group_of_parallel_layers=None, block_name=None):
if block_name is not None:
context_manager_kwargs = {'act_order': act_order, 'group_of_parallel_layers': group_of_parallel_layers, 'create_weight_orig': False}
block_optimization(model, dataloader, block_name, gpfq_mode, context_manager_kwargs)
else:
with gpfq_mode(model, act_order=act_order,
group_of_parallel_layers=group_of_parallel_layers) as gpfq:
gpfq_model = gpfq.model
for _ in tqdm(range(gpfq.num_layers)):
for inps in dataloader:
gpfq_model(**inps)
gpfq.update()
8 changes: 7 additions & 1 deletion src/brevitas_examples/llm/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,6 @@ def validate(args):
def main(args):
validate(args)
set_seed(args.seed)

if args.export_prefix is None:
args.export_prefix = f"{args.model.replace('/', '--')}"

Expand Down Expand Up @@ -340,6 +339,13 @@ def parse_args(args):
choices=['wikitext2', 'c4'],
default='wikitext2',
help='Dataset to use for quantization (default: %(default)s)')
parser.add_argument(
'--gptq-block-name',
type=str,
default=None,
help=
'Block name for faster GPTQ optimization. It works only if FX is not needed (default: %(default)s)'
)
parser.add_argument(
'--weight-bit-width', type=int, default=8, help='Weight bit width. Default: 8.')
parser.add_argument(
Expand Down

0 comments on commit 58d6f15

Please sign in to comment.