Quantize vit_b_16 tutorial - Part 1 (pytorch#60)
cpuhrsch authored Mar 22, 2024
1 parent 2871d74 commit 7ff1e42
Showing 9 changed files with 4,214 additions and 8 deletions.
5 changes: 4 additions & 1 deletion torchao/__init__.py
@@ -1,5 +1,8 @@
 from . import dtypes
+from .quantization.quant_api import apply_dynamic_quant
+from .quantization.quant_api import apply_weight_only_int8_quant
 
 __all__ = [
-    "dtypes"
+    "dtypes",
+    "apply_dynamic_quant",
 ]
10 changes: 3 additions & 7 deletions torchao/quantization/quant_api.py
@@ -126,14 +126,10 @@ def apply_weight_only_int8_quant(model, filter_fn=None):
 def apply_dynamic_quant(model, filter_fn=None):
     """
     Applies dynamic symmetric per-token activation and per-channel weight
-    quantization to all linear layers in the given model using
-    module swaps.
+    quantization to all linear layers by converting all linear weight
+    tensors to the `Int8DynamicallyQuantizedLinearWeight` Tensor subclass.
     """
-    _replace_with_custom_fn_if_matches_filter(
-        model,
-        lambda mod: DynamicallyPerAxisQuantizedLinear.from_float(mod),
-        _is_linear if filter_fn is None else filter_fn,
-    )
+    change_linear_weights_to_int8_dqtensors(model, filter_fn)
 
 
 def _get_subclass_inserter(cls, **kwargs):
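For intuition about the mechanism the new docstring describes, here is a minimal sketch of dynamic symmetric per-token activation and per-channel weight quantization for a linear layer. It is illustrative only, not the `Int8DynamicallyQuantizedLinearWeight` implementation: the helper names are invented for this example, and a production kernel would use a true int8 matmul (such as torch._int_mm) rather than the int32 upcast shown here.

import torch

def quantize_per_channel_sym(w):
    # Symmetric int8 quantization with one scale per output channel,
    # done once when the weight is converted.
    scale = w.abs().amax(dim=1, keepdim=True).clamp(min=1e-8) / 127.0
    w_int8 = torch.clamp(torch.round(w / scale), -127, 127).to(torch.int8)
    return w_int8, scale

def quantize_per_token_sym(x):
    # Symmetric int8 quantization with one scale per token; "dynamic"
    # because scales come from the live activation at every forward pass.
    scale = x.abs().amax(dim=-1, keepdim=True).clamp(min=1e-8) / 127.0
    x_int8 = torch.clamp(torch.round(x / scale), -127, 127).to(torch.int8)
    return x_int8, scale

def dynamic_quant_linear(x, w_int8, w_scale, bias=None):
    x_int8, x_scale = quantize_per_token_sym(x)
    # int32 accumulation (exact on CPU) stands in for a fused int8 matmul kernel.
    acc = x_int8.to(torch.int32) @ w_int8.to(torch.int32).t()
    out = acc.to(x.dtype) * x_scale * w_scale.t()
    return out if bias is None else out + bias

A float weight would be converted once with quantize_per_channel_sym, after which dynamic_quant_linear(x, w_int8, w_scale) approximates x @ w.t().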
Binary file added tutorials/quantize_vit/bfloat16.json.gz
1,682 changes: 1,682 additions & 0 deletions tutorials/quantize_vit/bfloat16_code.py


Binary file added tutorials/quantize_vit/quant.json.gz
2,413 changes: 2,413 additions & 0 deletions tutorials/quantize_vit/quant_code.py


13 changes: 13 additions & 0 deletions tutorials/quantize_vit/run.sh
@@ -0,0 +1,13 @@
#!/bin/bash

# Run bfloat16 version
TORCH_LOGS='graph_breaks,recompiles' python run_vit_b.py

# Run dynamic quantized version
TORCH_LOGS='graph_breaks,recompiles' python run_vit_b_quant.py

# Store the output code for further inspection
echo "bfloat16 generated code lives in:"
TORCH_LOGS='output_code' python run_vit_b.py 2>&1 | grep "Output code written to: " | awk -F" " '{print $NF}'
echo "quantization generated code lives in:"
TORCH_LOGS='output_code' python run_vit_b_quant.py 2>&1 | grep "Output code written to: " | awk -F" " '{print $NF}'
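
As an aside, the same logging can be enabled from Python with torch._logging.set_logs instead of the TORCH_LOGS environment variable; a small sketch, assuming a PyTorch build where these artifact names exist:

import torch._logging

# Equivalent to TORCH_LOGS='graph_breaks,recompiles,output_code' on the
# command line; inductor then logs the path of each generated output_code.py.
torch._logging.set_logs(graph_breaks=True, recompiles=True, output_code=True)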
46 changes: 46 additions & 0 deletions tutorials/quantize_vit/run_vit_b.py
@@ -0,0 +1,46 @@
import torch
import torchvision.models.vision_transformer as models

# Load Vision Transformer model
model = models.vit_b_16(pretrained=True)

# Set the model to evaluation mode and move it to CUDA in bfloat16
model.eval().cuda().to(torch.bfloat16)

# Input tensor (batch_size, channels, height, width)
input_tensor = torch.randn(1, 3, 224, 224, dtype=torch.bfloat16, device='cuda')

model = torch.compile(model, mode='max-autotune')

def benchmark_model(model, num_runs, input_tensor):
torch.cuda.synchronize()
start_event = torch.cuda.Event(enable_timing=True)
end_event = torch.cuda.Event(enable_timing=True)
start_event.record()

# benchmark
for _ in range(num_runs):
with torch.autograd.profiler.record_function("timed region"):
model(input_tensor)

end_event.record()
torch.cuda.synchronize()
return start_event.elapsed_time(end_event) / num_runs

def profiler_runner(path, fn, *args, **kwargs):
with torch.profiler.profile(
activities=[torch.profiler.ProfilerActivity.CPU,
torch.profiler.ProfilerActivity.CUDA],
record_shapes=True) as prof:
result = fn(*args, **kwargs)
prof.export_chrome_trace(path)
return result

# Must run with no_grad when optimizing for inference
with torch.no_grad():
# warmup
benchmark_model(model, 5, input_tensor)
# benchmark
print("elapsed_time: ", benchmark_model(model, 100, input_tensor), " milliseconds")
# Create a trace
profiler_runner("bfloat16.json.gz", benchmark_model, model, 5, input_tensor)
53 changes: 53 additions & 0 deletions tutorials/quantize_vit/run_vit_b_quant.py
@@ -0,0 +1,53 @@
import torch
import torchao
import torchvision.models.vision_transformer as models

# Load Vision Transformer model
model = models.vit_b_16(pretrained=True)

# Set the model to evaluation mode and move it to CUDA in bfloat16
model.eval().cuda().to(torch.bfloat16)

# Input tensor (batch_size, channels, height, width)
input_tensor = torch.randn(1, 3, 224, 224, dtype=torch.bfloat16, device='cuda')

## Quantization code - start
torchao.apply_dynamic_quant(model)
from torch._inductor import config as inductorconfig
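# Ask inductor to fuse the int8 matmul with the rescaling multiply that
# follows it, so the quantized linear runs as a single kernel.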
inductorconfig.force_fuse_int_mm_with_mul = True
## Quantization code - end

model = torch.compile(model, mode='max-autotune')

def benchmark_model(model, num_runs, input_tensor):
torch.cuda.synchronize()
start_event = torch.cuda.Event(enable_timing=True)
end_event = torch.cuda.Event(enable_timing=True)
start_event.record()

# benchmark
for _ in range(num_runs):
with torch.autograd.profiler.record_function("timed region"):
model(input_tensor)

end_event.record()
torch.cuda.synchronize()
return start_event.elapsed_time(end_event) / num_runs

def profiler_runner(path, fn, *args, **kwargs):
with torch.profiler.profile(
activities=[torch.profiler.ProfilerActivity.CPU,
torch.profiler.ProfilerActivity.CUDA],
record_shapes=True) as prof:
result = fn(*args, **kwargs)
prof.export_chrome_trace(path)
return result

# Must run with no_grad when optimizing for inference
with torch.no_grad():
# warmup
benchmark_model(model, 5, input_tensor)
# benchmark
print("elapsed_time: ", benchmark_model(model, 100, input_tensor), " milliseconds")
# Create a trace
profiler_runner("quant.json.gz", benchmark_model, model, 5, input_tensor)
