[Quant][PT2E] Enable X86InductorQuantizer single quantizable op(maxpool2d)

ghstack-source-id: 0a0d7a11ebfb995a2d840d82667a193e92da62ee
Pull Request resolved: pytorch#105639
leslie-fang-intel committed Aug 21, 2023
1 parent d4fcf68 commit 315e8eb
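
For context, here is a minimal sketch of the flow the new test below exercises: quantizing a Conv2d -> MaxPool2d -> pow model with X86InductorQuantizer so that maxpool2d is treated as an int8-in/int8-out op. The module and the quantizer/prepare/convert calls mirror the diff; the torchdynamo.export entry point is assumed from the test helper used at the time of this PR and may differ in newer PyTorch releases.

import copy
import torch
import torch.nn as nn
import torch._dynamo as torchdynamo
import torch.ao.quantization.quantizer.x86_inductor_quantizer as xiq
from torch.ao.quantization.quantizer.x86_inductor_quantizer import X86InductorQuantizer
from torch.ao.quantization.quantize_pt2e import prepare_pt2e, convert_pt2e


class Conv2dMaxpoolPowModule(nn.Module):
    """Conv2d -> MaxPool2d -> pow, mirroring the helper module added in this diff."""
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(2, 2, 1)
        self.pool = nn.MaxPool2d(1, 1)

    def forward(self, x):
        return torch.pow(self.pool(self.conv(x)), 2)


model = Conv2dMaxpoolPowModule().eval()
example_inputs = (torch.rand(1, 2, 14, 14),)

# Capture an ATen-level graph (export API assumed from the test helper of this era).
m, _ = torchdynamo.export(model, *copy.deepcopy(example_inputs), aten_graph=True)

# Default x86 Inductor recipe; with this PR, maxpool2d is annotated as an
# int8-in/int8-out op, so its input and output share a single observer.
quantizer = X86InductorQuantizer().set_global(
    xiq.get_default_x86_inductor_quantization_config()
)

m = prepare_pt2e(m, quantizer)  # insert observers
m(*example_inputs)              # calibrate
m = convert_pt2e(m)             # lower observers to quantize/dequantize ops
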
Showing 2 changed files with 310 additions and 16 deletions.
81 changes: 80 additions & 1 deletion test/quantization/pt2e/test_x86inductor_quantizer.py
@@ -20,7 +20,8 @@
from enum import Enum
import itertools
import torch.ao.quantization.quantizer.x86_inductor_quantizer as xiq

import operator
from torch.ao.quantization import ObserverBase

class Conv2DType(Enum):
left = 1
@@ -127,6 +128,18 @@ def forward(self, x):
else:
return self.relu2(self.conv(x) + self.conv2(x))

class Conv2dMaxpoolPowModule(nn.Module):
def __init__(self):
super().__init__()
self.conv = nn.Conv2d(2, 2, 1)
self.pool = nn.MaxPool2d(1, 1)

def forward(self, x):
x = self.conv(x)
x = self.pool(x)
return torch.pow(x, 2)


class SerialsConv2dAddReLUModule(torch.nn.Module):
""" Serials of 2 Conv2d -> Add -> ReLU Pattern.
"""
@@ -171,10 +184,13 @@ def _test_quantizer(
*copy.deepcopy(example_inputs),
aten_graph=True,
)
export_model = copy.deepcopy(m)
m = prepare_pt2e(m, quantizer)
# Calibrate
m(*example_inputs)
prepare_model = copy.deepcopy(m)
m = convert_pt2e(m)
convert_model = copy.deepcopy(m)
pt2_quant_output = m(*example_inputs)
node_occurrence = {
ns.call_function(k): v for k, v in expected_node_occurrence.items()
@@ -185,6 +201,7 @@
self.checkGraphModuleNodes(
m, expected_node_occurrence=node_occurrence, expected_node_list=node_list
)
return export_model, prepare_model, convert_model

@skipIfNoDynamoSupport
class TestQuantizePT2EX86Inductor(X86InductorQuantTestCase):
@@ -400,3 +417,65 @@ def test_conv2d_serials_binary_unary_with_quantizer_api(self):
node_occurrence,
node_list,
)

@skipIfNoX86
def test_maxpool2d_recipe(self):
r"""
Test pattern: int8_in_int8_out_ops(maxpool) - non_quantizable op(pow)
Since maxpool is an int8_in_int8_out_op, there is an observer between maxpool and pow.
"""
m = TestHelperModules.Conv2dMaxpoolPowModule().eval()
x = torch.rand(1, 2, 14, 14)
quantizer = X86InductorQuantizer().set_global(
xiq.get_default_x86_inductor_quantization_config()
)
example_inputs = (x,)
node_occurrence = {
# one for the conv input (the conv weight uses per-channel quantization below),
# two for the input/output of the maxpool2d
torch.ops.quantized_decomposed.quantize_per_tensor.default: 3,
torch.ops.quantized_decomposed.dequantize_per_tensor.default: 3,
torch.ops.quantized_decomposed.quantize_per_channel.default: 1,
torch.ops.quantized_decomposed.dequantize_per_channel.default: 1,
}
node_list = [
torch.ops.quantized_decomposed.quantize_per_tensor.default,
torch.ops.quantized_decomposed.dequantize_per_tensor.default,
torch.ops.aten.convolution.default,
torch.ops.quantized_decomposed.quantize_per_tensor.default,
torch.ops.quantized_decomposed.dequantize_per_tensor.default,
torch.ops.aten.max_pool2d_with_indices.default,
torch.ops.quantized_decomposed.quantize_per_tensor.default,
torch.ops.quantized_decomposed.dequantize_per_tensor.default,
]
_, prepare_model, _ = self._test_quantizer(
m,
example_inputs,
quantizer,
node_occurrence,
node_list,
)
# Check that maxpool2d shares the same observer at its input and output
for node in prepare_model.graph.nodes:
if (
node.op == "call_function"
and node.target is torch.ops.aten.max_pool2d_with_indices.default
):
maxpool_node = node
input_obs_of_maxpool = getattr(
prepare_model, maxpool_node.args[0].target
)
elif node.op == "call_function" and node.target is operator.getitem:
output_obs_of_maxpool = getattr(
prepare_model, list(node.users)[0].target
)
elif (
node.op == "call_function"
and node.target is torch.ops.aten.convolution.default
):
conv_node = node
input_obs_of_conv = getattr(prepare_model, conv_node.args[0].target)
self.assertTrue(isinstance(input_obs_of_maxpool, ObserverBase))
self.assertTrue(isinstance(output_obs_of_maxpool, ObserverBase))
self.assertTrue(isinstance(input_obs_of_conv, ObserverBase))
self.assertTrue(input_obs_of_maxpool is output_obs_of_maxpool)
self.assertTrue(input_obs_of_maxpool is not input_obs_of_conv)