diff --git a/src/brevitas_examples/llm/llm_quant/gpxq.py b/src/brevitas_examples/llm/llm_quant/gpxq.py
index e2bfba989..3a7d732f9 100644
--- a/src/brevitas_examples/llm/llm_quant/gpxq.py
+++ b/src/brevitas_examples/llm/llm_quant/gpxq.py
@@ -3,39 +3,105 @@
 # SPDX-License-Identifier: BSD-3-Clause
 """
+
+from copy import deepcopy
+
+from accelerate.utils.operations import send_to_device
 import torch
 from tqdm import tqdm
 
 from brevitas.graph.gpfq import gpfq_mode
 from brevitas.graph.gptq import gptq_mode
+from brevitas.graph.gpxq import StopFwdException
+from brevitas.utils.python_utils import recurse_getattr
+
+
+@torch.no_grad()
+def block_optimization(model, dataloader, block_name, context_manager_func, context_manager_kwargs):
+    cache_state = model.config.use_cache
+    model.config.use_cache = False
+    blocks = recurse_getattr(model, block_name)
+    first_block = blocks[0]
+    cached_args, cached_kwargs = [], []
+
+    # Intercept input to first block
+    def intercept_input(module, args, kwargs):
+        args = send_to_device(args, 'cpu')
+        kwargs = send_to_device(kwargs, 'cpu')
+        cached_args.append(args)
+        cached_kwargs.append(kwargs)
+        raise StopFwdException
+
+    # Intercept output from block N-1 to set it as input to block N
+    def intercept_output(module, args, kwargs, output):
+        if isinstance(output, tuple):
+            output = output[0]
+        output = send_to_device(output, 'cpu')
+        cached_args.append((output,))
+        raise StopFwdException
+
+    # Collect input to first block
+    hook = first_block.register_forward_pre_hook(intercept_input, with_kwargs=True)
+    for inps in dataloader:
+        try:
+            model(**inps)
+        except StopFwdException:
+            pass
+    hook.remove()
+
+    # Iterate through all the blocks
+    for index, block in enumerate(tqdm(blocks)):
+        with context_manager_func(block, **context_manager_kwargs) as gpxq:
+            for _ in tqdm(range(gpxq.num_layers)):
+                for args, kwargs in zip(cached_args, cached_kwargs):
+                    args = send_to_device(args, 'cuda')
+                    kwargs = send_to_device(kwargs, 'cuda')
+                    block(*args, **kwargs)
+                gpxq.update()
+        if index < len(blocks) - 1:
+            # Once the block is done, we need to update the input to the next block
+            past_cached_args, past_cached_kwargs = deepcopy(cached_args), deepcopy(cached_kwargs)
+            cached_args = []
+            hook = block.register_forward_hook(intercept_output, with_kwargs=True)
+            for args, kwargs in zip(past_cached_args, past_cached_kwargs):
+                try:
+                    args = send_to_device(args, 'cuda')
+                    kwargs = send_to_device(kwargs, 'cuda')
+                    block(*args, **kwargs)
+                except StopFwdException:
+                    pass
+            hook.remove()
+    # Restore cache state
+    model.config.use_cache = cache_state
 
 
 @torch.no_grad()
-def apply_gptq(
-        model,
-        dataloader,
-        act_order=True,
-        group_of_parallel_layers=None,
-        use_quant_activations=True,
-        create_weight_orig=False):
-    with gptq_mode(model,
-                   act_order=act_order,
-                   group_of_parallel_layers=group_of_parallel_layers,
-                   use_quant_activations=use_quant_activations,
-                   create_weight_orig=create_weight_orig) as gptq:
-        gptq_model = gptq.model
-        for _ in tqdm(range(gptq.num_layers)):
-            for inps in dataloader:
-                gptq_model(**inps)
-            gptq.update()
+def apply_gptq(model, dataloader, act_order=True, group_of_parallel_layers=None, block_name=None):
+    if block_name is not None:
+        context_manager_kwargs = {'act_order': act_order, 'group_of_parallel_layers': group_of_parallel_layers, 'create_weight_orig': False, 'use_quant_activations': False}
+        block_optimization(model, dataloader, block_name, gptq_mode, context_manager_kwargs)
+    else:
+        with gptq_mode(model,
+                       use_quant_activations=False,
+                       group_of_parallel_layers=group_of_parallel_layers,
+                       act_order=act_order,
+                       create_weight_orig=False) as gptq:
+            gptq_model = gptq.model
+            for _ in tqdm(range(gptq.num_layers)):
+                for inps in dataloader:
+                    gptq_model(**inps)
+                gptq.update()
 
 
 @torch.no_grad()
-def apply_gpfq(model, dataloader, act_order=True, group_of_parallel_layers=None):
-    with gpfq_mode(model, act_order=act_order,
-                   group_of_parallel_layers=group_of_parallel_layers) as gpfq:
-        gpfq_model = gpfq.model
-        for _ in tqdm(range(gpfq.num_layers)):
-            for inps in dataloader:
-                gpfq_model(**inps)
-            gpfq.update()
+def apply_gpfq(model, dataloader, act_order=True, group_of_parallel_layers=None, block_name=None):
+    if block_name is not None:
+        context_manager_kwargs = {'act_order': act_order, 'group_of_parallel_layers': group_of_parallel_layers, 'create_weight_orig': False}
+        block_optimization(model, dataloader, block_name, gpfq_mode, context_manager_kwargs)
+    else:
+        with gpfq_mode(model, act_order=act_order,
+                       group_of_parallel_layers=group_of_parallel_layers) as gpfq:
+            gpfq_model = gpfq.model
+            for _ in tqdm(range(gpfq.num_layers)):
+                for inps in dataloader:
+                    gpfq_model(**inps)
+                gpfq.update()
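Not part of the patch: a minimal usage sketch of the new block-wise entry point. The model choice, the OPT-style block path `model.decoder.layers`, the CUDA placement, and the toy calibration loader are illustrative assumptions; in the real flow the model is first rewritten with Brevitas quant layers before `apply_gptq` is called.

```python
import torch
from transformers import AutoModelForCausalLM

from brevitas_examples.llm.llm_quant.gpxq import apply_gptq

# Hypothetical setup: in the real pipeline the model has already been rewritten
# with Brevitas quant layers; only the call shape of the new API is shown here.
model = AutoModelForCausalLM.from_pretrained('facebook/opt-125m').cuda()

# Each calibration sample is a dict of model kwargs, because block_optimization
# captures the first block's input by replaying `model(**inps)` under a pre-hook.
calib_loader = [
    {'input_ids': torch.randint(0, 50272, (1, 128), device='cuda')} for _ in range(4)]

# Block-wise GPTQ: each block under 'model.decoder.layers' (OPT naming; adjust
# per architecture) is optimized in isolation, with its inputs cached on CPU
# between blocks.
apply_gptq(model, calib_loader, act_order=True, block_name='model.decoder.layers')

# block_name=None (the default) keeps the original whole-model gptq_mode loop.
```

The speed-up over the whole-model path comes from the `StopFwdException` hooks: the forward pass is cut off at the first block while its inputs are cached, and every later block is fed the cached tensors directly, so the `num_layers` inner iterations run one block at a time instead of the entire model.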
diff --git a/src/brevitas_examples/llm/main.py b/src/brevitas_examples/llm/main.py
index c33de54c8..f7ca0810d 100644
--- a/src/brevitas_examples/llm/main.py
+++ b/src/brevitas_examples/llm/main.py
@@ -118,7 +118,6 @@ def validate(args):
 def main(args):
     validate(args)
     set_seed(args.seed)
-
     if args.export_prefix is None:
         args.export_prefix = f"{args.model.replace('/', '--')}"
 
@@ -340,6 +339,13 @@ def parse_args(args):
         choices=['wikitext2', 'c4'],
         default='wikitext2',
         help='Dataset to use for quantization (default: %(default)s)')
+    parser.add_argument(
+        '--gptq-block-name',
+        type=str,
+        default=None,
+        help=
+        'Block name for faster GPTQ optimization. It works only if FX is not needed (default: %(default)s)'
+    )
     parser.add_argument(
         '--weight-bit-width', type=int, default=8, help='Weight bit width. Default: 8.')
     parser.add_argument(
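For completeness, a sketch of the new flag from the CLI side. That `--model` and `--gptq` exist in this entry point is an assumption based on the surrounding parser, and the block path value is again just an OPT-style example:

```python
from brevitas_examples.llm.main import parse_args

# Hypothetical invocation; the value of --gptq-block-name must match the
# attribute path of the model's block ModuleList (OPT naming shown).
args = parse_args([
    '--model', 'facebook/opt-125m',
    '--gptq',
    '--gptq-block-name', 'model.decoder.layers'])
assert args.gptq_block_name == 'model.decoder.layers'
```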