# -*- mode: python -*-
# =============================================================================
# @@-COPYRIGHT-START-@@
#
# Copyright (c) 2023-2024, Qualcomm Innovation Center, Inc. All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice,
# this list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its contributors
# may be used to endorse or promote products derived from this software
# without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
# ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
# LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
# CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
# SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
# CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
# POSSIBILITY OF SUCH DAMAGE.
#
# SPDX-License-Identifier: BSD-3-Clause
#
# @@-COPYRIGHT-END-@@
# =============================================================================
# pylint: disable=redefined-builtin
""" Affine quantizers """
import abc
from itertools import chain, repeat
from typing import Optional, List, Dict, Tuple, overload
import contextlib
import functools
import torch
from torch import nn
from aimet_torch.v2.utils import patch_attr, _is_expandable, StatisticsNotFoundError, docstring
from aimet_torch.v2.quantization.encoding_analyzer import EncodingAnalyzer, MinMaxEncodingAnalyzer, _flag_extreme_min_max
from aimet_torch.v2.quantization.affine import AffineEncoding
from aimet_torch.v2.quantization.tensor import QuantizedTensor, DequantizedTensor
from aimet_torch.v2.quantization.base import QuantizerBase
from aimet_torch.v2.quantization.affine.backends import quantize, quantize_dequantize, torch_builtins, _derive_qmin_qmax
from aimet_torch.v2.utils import ste_round
from ._utils import _GridMixin, _register_signature # pylint: disable=import-error
__all__ = ['AffineQuantizerBase', 'MinMaxQuantizer', 'Quantize', 'QuantizeDequantize',
'GroupedBlockQuantizeDequantize']
class AffineQuantizerBase(QuantizerBase, _GridMixin):
"""
Base class for linear quantization modules.
Args:
shape (tuple): Shape of the quantization parameters
bitwidth (int): Quantization bitwidth
symmetric (bool): If True, performs symmetric quantization;
otherwise, performs asymmetric quantization
encoding_analyzer (EncodingAnalyzer, optional): Encoding analyzer for calibrating quantization encodings
(default: absolute min-max encoding analyzer)
"""
_init_signatures = []
@overload
@_register_signature(_init_signatures)
def __init__(self, shape, qmin: int, qmax: int, symmetric: bool, encoding_analyzer: EncodingAnalyzer = None,
block_size: Optional[Tuple[int, ...]] = None):
...
@overload
@_register_signature(_init_signatures)
def __init__(self, shape, bitwidth: int, symmetric: bool, encoding_analyzer: EncodingAnalyzer = None,
block_size: Optional[Tuple[int, ...]] = None):
...
def __init__(self, shape, *args, **kwargs):
super().__init__()
if isinstance(shape, int):
shape = (shape,)
self.shape = tuple(shape)
full_args = (shape, *args)
# Pad positional args with None's such that len(args) == 5
args = tuple(chain(args, repeat(None, 5 - len(args))))
arg0 = kwargs.pop('qmin', kwargs.pop('bitwidth', args[0]))
arg1 = kwargs.pop('qmax', args[1])
if arg1 is not None and not isinstance(arg1, bool):
# (arg0, arg1, arg2) == (qmin, qmax, symmetric)
qmin, qmax = arg0, arg1
symmetric = kwargs.pop('symmetric', args[2])
if (qmin is None) or (qmax is None) or (symmetric is None):
raise self._arg_parsing_error(full_args, kwargs)
encoding_analyzer = kwargs.pop('encoding_analyzer', args[3])
block_size = kwargs.pop('block_size', args[4])
else:
# (arg0, arg1) == (bitwidth, symmetric)
bitwidth = arg0
symmetric = kwargs.pop('symmetric', args[1])
if (bitwidth is None) or (symmetric is None):
raise self._arg_parsing_error(full_args, kwargs)
# We support two quantization modes: (unsigned) asymmetric and signed-symmetric
qmin, qmax = _derive_qmin_qmax(bitwidth=bitwidth, signed=symmetric)
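# For illustration (assuming the usual two's-complement / unsigned integer grids):
#   bitwidth=8, signed=True  -> (qmin, qmax) = (-128, 127)
#   bitwidth=8, signed=False -> (qmin, qmax) = (0, 255)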
encoding_analyzer = kwargs.pop('encoding_analyzer', args[2])
block_size = kwargs.pop('block_size', args[3])
assert qmin is not None
assert qmax is not None
if kwargs:
cls = type(self).__qualname__
unexpected_keys = ', '.join(kwargs.keys())
raise TypeError(f"{cls}.__init__ got unexpected keyword argument: {unexpected_keys}")
if qmin >= qmax:
raise ValueError(f"qmax should be strictly larger than qmin. Got qmax={qmax}, qmin={qmin}")
self.qmin = qmin
self.qmax = qmax
self._symmetric = symmetric
self.block_size = block_size
self.encoding_analyzer = encoding_analyzer or \
MinMaxEncodingAnalyzer(torch_builtins.get_encoding_shape_with_blocks(self.shape,
self.block_size))
if self.block_size is None and not _is_expandable(self.encoding_analyzer.observer.shape, self.shape):
raise RuntimeError(f'Encoding analyzer of shape {self.encoding_analyzer.observer.shape} '
f'is incompatible with quantizer of shape {self.shape}.')
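# Example call patterns accepted by this constructor (illustrative values, using the concrete
# subclass Quantize defined below):
#   Quantize(shape=(1,), qmin=-128, qmax=127, symmetric=True)   # explicit integer grid
#   Quantize(shape=(1,), bitwidth=8, symmetric=True)            # grid derived from bitwidth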
@abc.abstractmethod
def get_min(self, dtype=None) -> torch.Tensor:
"""
Compute quantization min to be used for forward pass.
Return None if the quantizer is not initialized yet.
Args:
dtype (torch.dtype): dtype of the computed min
Returns:
Quantization min
"""
@abc.abstractmethod
def get_max(self, dtype=None) -> torch.Tensor:
"""
Compute quantization max to be used for forward pass.
Return None if the quantizer is not initialized yet.
Args:
dtype (torch.dtype): dtype of the computed max
Returns:
Quantization max
"""
@abc.abstractmethod
def get_scale(self, dtype=None) -> torch.Tensor:
"""
Compute quantization scale to be used for forward pass.
Return None if the quantizer is not initialized yet.
Args:
dtype (torch.dtype): dtype of the computed scale
Returns:
Quantization scale
"""
@abc.abstractmethod
def get_offset(self, dtype=None) -> torch.Tensor:
"""
Compute quantization offset to be used for forward pass.
Return None if the quantizer is not initialized yet.
Args:
dtype (torch.dtype): dtype of the computed offset
Returns:
Quantization offset
"""
@abc.abstractmethod
def set_range(self, min: torch.Tensor, max: torch.Tensor):
"""
Set quantization parameters to the given min-max range
"""
def get_encoding(self) -> Optional[AffineEncoding]:
"""
Return the quantizer's encodings as an AffineEncoding object
"""
if self.is_initialized():
return AffineEncoding(self.get_scale(dtype=torch.float32),
self.get_offset(dtype=torch.float32),
self.qmin, self.qmax, self._symmetric, self.block_size)
return None
@torch.no_grad()
def get_legacy_encodings(self) -> Optional[List[Dict]]:
"""
Returns a list of encodings, each represented as a List of Dicts
"""
# pylint: disable=redefined-builtin, protected-access
if not self.is_initialized():
return None
return self.get_encoding()._to_legacy_format()
@torch.no_grad()
def set_legacy_encodings(self, encodings: List[Dict]):
"""
Set encodings using the same format as the output of get_legacy_encodings, as shown below:
[
{'min': float, 'max': float, 'scale': float, 'offset': float,
'bitwidth': int, 'dtype': str, 'is_symmetric': str},
{'min': float, 'max': float, 'scale': float, 'offset': float,
'bitwidth': int, 'dtype': str, 'is_symmetric': str},
...
]
"""
def str_to_bool(s: str):
s = s.lower()
if s == "false":
return False
if s == "true":
return True
raise ValueError
bitwidth = encodings[0]['bitwidth']
symmetric = str_to_bool(encodings[0]['is_symmetric'])
# We support two quantization modes: (unsigned) asymmetric and signed-symmetric
self.qmin, self.qmax = _derive_qmin_qmax(bitwidth=bitwidth, signed=symmetric)
self.symmetric = symmetric
# Note: We can only accurately infer signed-ness in the symmetric case, but AIMET uses unsigned for asymmetric
min_ = torch.tensor([e['min'] for e in encodings]).view(self.shape)
max_ = torch.tensor([e['max'] for e in encodings]).view(self.shape)
self.set_range(min_, max_)
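# Minimal usage sketch (hypothetical values, for a quantizer of shape (1,)):
#   quantizer.set_legacy_encodings([{'min': -1.0, 'max': 1.0, 'scale': 2 / 255, 'offset': -128,
#                                    'bitwidth': 8, 'dtype': 'int', 'is_symmetric': 'False'}])
# Only 'min', 'max', 'bitwidth', and 'is_symmetric' are consumed above; 'scale', 'offset', and
# 'dtype' belong to the legacy format but are re-derived from the min-max range.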
def extra_repr(self) -> str:
return f'shape={self.shape}, qmin={self.qmin}, qmax={self.qmax}, symmetric={self.symmetric}'
@property
def symmetric(self) -> bool:
"""
Indicates whether this quantizer uses symmetric quantization
"""
return self._symmetric
@symmetric.setter
def symmetric(self, symmetric: bool):
"""
Set the quantizer symmetry
:param symmetric: If True, use symmetric encodings. Else, use asymmetric encodings
"""
self._symmetric = symmetric
@property
@docstring(_GridMixin._get_bitwidth.__doc__)
def bitwidth(self) -> int: # pylint: disable=missing-function-docstring
return self._get_bitwidth()
@bitwidth.setter
def bitwidth(self, bitwidth: int):
self._set_bitwidth(bitwidth)
@property
@docstring(_GridMixin._get_signed.__doc__)
def signed(self) -> bool: # pylint: disable=missing-function-docstring
return self._get_signed()
@signed.setter
def signed(self, signed: bool):
self._set_signed(signed)
class MinMaxQuantizer(AffineQuantizerBase): # pylint: disable=abstract-method
"""
Affine quantizer with min-max as trainable parameters
"""
min: torch.nn.Parameter
max: torch.nn.Parameter
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.register_quantization_parameter('min', nn.Parameter(-torch.ones(self.shape)))
self.register_quantization_parameter('max', nn.Parameter(torch.ones(self.shape)))
@contextlib.contextmanager
def compute_encodings(self):
"""
Observe inputs and update quantization parameters based on the input statistics.
While ``compute_encodings`` is active, the quantizer's forward pass performs
dynamic quantization using the batch statistics.
"""
if not self._allow_overwrite:
yield
return
original_forward = self.forward
@functools.wraps(original_forward)
def forward_wrapper(input):
input = input.as_subclass(torch.Tensor)
expanded_input = torch_builtins.reshape_tensor_for_blocks(input, self.shape, self.block_size)
batch_statistics = self.encoding_analyzer.update_stats(expanded_input)
num_steps = self.qmax - self.qmin
dynamic_min, dynamic_max =\
self.encoding_analyzer.compute_encodings_from_stats(batch_statistics,
num_steps,
self.symmetric)
if self.block_size is not None:
dynamic_min = dynamic_min.view(self.min.shape)
dynamic_max = dynamic_max.view(self.max.shape)
dynamic_min = dynamic_min.to(dtype=self.min.dtype,
device=self.min.device).expand_as(self.min)
dynamic_max = dynamic_max.to(dtype=self.max.dtype,
device=self.max.device).expand_as(self.max)
with patch_attr(self, 'min', dynamic_min),\
patch_attr(self, 'max', dynamic_max):
return original_forward(input)
self.encoding_analyzer.reset_stats()
try:
with patch_attr(self, 'forward', forward_wrapper):
yield
except: # pylint: disable=try-except-raise
raise
else:
try:
num_steps = self.qmax - self.qmin
enc_min, enc_max = self.encoding_analyzer.compute_encodings(num_steps, self.symmetric)
if self.block_size is not None:
enc_min = enc_min.view(self.min.shape)
enc_max = enc_max.view(self.max.shape)
_flag_extreme_min_max(enc_min, enc_max)
except StatisticsNotFoundError:
return
if enc_min is None or enc_max is None:
return
self.set_range(enc_min, enc_max)
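# Typical calibration flow (a minimal sketch; `calibration_data` is a hypothetical iterable of tensors):
#   q = QuantizeDequantize(shape=(1,), bitwidth=8, symmetric=False)
#   with q.compute_encodings():
#       for x in calibration_data:
#           _ = q(x)   # forward passes observe input statistics
#   # On exit, min/max are set from the accumulated statistics via set_range()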
def get_min(self, dtype=None) -> Optional[torch.Tensor]:
"""
Compute quantization min to be used for forward pass.
NOTE: self.min may not be equal to self.get_min().
self.get_min() returns a slightly recalibrated version of self.min.
:param dtype: dtype of the computed min. Defaults to torch.float32.
:return: Quantization min
"""
if not self.is_initialized():
return None
return self.get_scale(dtype) * (self.get_offset(dtype) + self.qmin)
def get_max(self, dtype=None) -> Optional[torch.Tensor]:
"""
Compute quantization max to be used for forward pass.
NOTE: self.max may not be equal to self.get_max().
self.get_max() returns a slightly recalibrated version of self.max.
:param dtype: dtype of the computed max. Defaults to torch.float32.
:return: Quantization max
"""
if not self.is_initialized():
return None
return self.get_scale(dtype) * (self.get_offset(dtype) + self.qmax)
def get_scale(self, dtype=None) -> Optional[torch.Tensor]:
"""
Compute quantization scale to be used for forward pass.
:param dtype: dtype of the computed scale. Defaults to torch.float32.
:return: Quantization scale
"""
if not self.is_initialized():
return None
dtype = dtype or torch.float32
num_steps = self.qmax - self.qmin
scale = (self.max.to(dtype) - self.min.to(dtype)) / num_steps
return scale.to(dtype)
def get_offset(self, dtype=None) -> Optional[torch.Tensor]:
"""
Compute quantization offset to be used for forward pass.
:param dtype: dtype of the computed offset. Defaults to torch.float32.
:return: Quantization offset
"""
if not self.is_initialized():
return None
dtype = dtype or torch.float32
if self.symmetric:
offset = torch.zeros_like(self.min, requires_grad=False, dtype=dtype)
else:
offset = ste_round(self.min.to(dtype) / self.get_scale(dtype)) - self.qmin
return offset.to(dtype)
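# Worked example of the scale/offset arithmetic above (8-bit asymmetric, so qmin=0, qmax=255):
#   min = -1.0, max = 1.0
#   scale  = (max - min) / 255            ~= 0.00784
#   offset = round(min / scale) - qmin    ~= -128
#   get_min() = scale * (offset + qmin)   ~= -1.004
#   get_max() = scale * (offset + qmax)   ~=  0.996
# (values are illustrative; exact results depend on rounding)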
def set_range(self, min: torch.Tensor, max: torch.Tensor):
"""
Set quantization parameters to the given min-max range
"""
with torch.no_grad():
self.min.copy_(min)
self.max.copy_(max)
class Quantize(MinMaxQuantizer):
r"""Applies quantization to the input.
Precisely,
.. math::
out = clamp\left(\left\lceil\frac{input}{scale}\right\rfloor - offset, qmin, qmax\right)
where :math:`scale` and :math:`offset` are derived from learnable parameters
:math:`\theta_{min}` and :math:`\theta_{max}`.
If block size :math:`B = \begin{pmatrix} B_0 & B_1 & \cdots & B_{D-1} \end{pmatrix}` is specified,
this equation will be further generalized as
.. math::
out_{j_0 \cdots j_{D-1}} & = clamp\left(
\left\lceil\frac{input_{j_0 \cdots j_{D-1}}}{scale_{i_0 \cdots i_{D-1}}}\right\rfloor
- offset_{i_0 \cdots i_{D-1}}, qmin, qmax\right)\\
\text{where} \quad \forall_{0 \leq d < D} \quad i_d = \left\lfloor \frac{j_d}{B_d} \right\rfloor
Args:
shape (tuple): Shape of the quantization parameters
bitwidth (int): Quantization bitwidth
symmetric (bool): If True, performs symmetric quantization;
otherwise, performs asymmetric quantization
encoding_analyzer (EncodingAnalyzer, optional): Encoding analyzer for calibrating quantization encodings
(default: absolute min-max encoding analyzer)
block_size (Tuple[int, ...], optional): Block size
:ivar Tensor min: :math:`\theta_{min}` from which scale and offset will be derived.
:ivar Tensor max: :math:`\theta_{max}` from which scale and offset will be derived.
.. note::
:class:`Quantize` cannot run :meth:`forward` until :attr:`min` and :attr:`max` are properly initialized,
which can be done based on input statistics using :meth:`compute_encodings` or
by manually assigning a new value to :attr:`min` and :attr:`max`.
See the examples below.
Examples:
>>> import aimet_torch.v2.quantization as Q
>>> input = torch.randn(5, 10)
>>> q = Q.affine.Quantize(shape=(5, 1), bitwidth=8, symmetric=False, block_size=(1, 5))
>>> q.is_initialized()
False
>>> with q.compute_encodings():
... _ = q(input)
...
>>> q.is_initialized()
True
>>> q(input)
QuantizedTensor([[129., 64., 255., 122., 0., 192., 106., 94., 255., 0.],
[ 0., 145., 181., 255., 144., 255., 194., 0., 74., 86.],
[122., 0., 255., 150., 33., 103., 103., 0., 37., 255.],
[255., 111., 237., 218., 0., 49., 155., 255., 0., 179.],
[ 0., 66., 255., 89., 110., 17., 36., 83., 255., 0.]],
grad_fn=<AliasBackward0>)
>>> import aimet_torch.v2.quantization as Q
>>> input = torch.randn(5, 10)
>>> q = Q.affine.Quantize(shape=(5, 1), bitwidth=8, symmetric=False, block_size=(1, 5))
>>> q.is_initialized()
False
>>> q.min = torch.nn.Parameter(-torch.ones_like(q.min))
>>> q.max = torch.nn.Parameter(torch.ones_like(q.max))
>>> q.is_initialized()
True
>>> q(input)
QuantizedTensor([[187., 186., 131., 0., 203., 64., 80., 0., 143., 152.],
[ 16., 0., 255., 0., 0., 150., 0., 255., 32., 255.],
[255., 226., 0., 255., 55., 172., 0., 255., 145., 255.],
[207., 146., 216., 238., 0., 0., 141., 178., 255., 188.],
[ 63., 59., 19., 162., 30., 255., 109., 255., 0., 255.]],
grad_fn=<AliasBackward0>)
"""
# NOTE: Deepspeed has a bug where it will inadvertently patch the __init__ method permanently
# unless each leaf class explicitly defines its own __init__ separately.
# As a temporary workaround, we define __init__ to avoid triggering this bug.
# pylint: disable=useless-super-delegation
def __init__(self, shape, *args, **kwargs):
super().__init__(shape, *args, **kwargs)
def forward(self, input: torch.Tensor) -> QuantizedTensor:
"""Quantizes the input tensor
Args:
input (torch.Tensor): Input to quantize
Returns:
Quantized output
"""
if not self.is_initialized():
raise RuntimeError(
'Failed to run Quantize since quantization parameters are not initialized.'
' Please initialize the quantization parameters using `compute_encodings()`.'
)
encoding = self.get_encoding()
# Subclasses of torch.Tensor with a custom __torch_function__ (in our case, QuantizedTensorBase)
# are known to introduce substantial CPU overhead.
# Cast the input to a plain torch.Tensor for faster execution.
input = input.as_subclass(torch.Tensor)
output = quantize(input,
encoding.scale,
encoding.offset,
encoding.qmin,
encoding.qmax,
block_size=self.block_size)
output = output.as_subclass(QuantizedTensor)
output.encoding = encoding
return output
class QuantizeDequantize(MinMaxQuantizer):
r"""Applies fake-quantization by quantizing and dequantizing the input.
Precisely,
.. math::
out = (\overline{input} + offset) * scale
where
.. math::
\overline{input} = clamp\left(\left\lceil\frac{input}{scale}\right\rfloor - offset, qmin, qmax\right)
and :math:`scale` and :math:`offset` are derived from learnable parameters
:math:`\theta_{min}` and :math:`\theta_{max}`.
If block size :math:`B = \begin{pmatrix} B_0 & B_1 & \cdots & B_{D-1} \end{pmatrix}` is specified,
this equation will be further generalized as
.. math::
out_{j_0 \cdots j_{D-1}} &= (\overline{input}_{j_0 \cdots j_{D-1}} + offset_{i_0 \cdots i_{D-1}}) * scale_{i_0 \cdots i_{D-1}}\\
\overline{input}_{j_0 \cdots j_{D-1}} &= clamp\left(
\left\lceil\frac{input_{j_0 \cdots j_{D-1}}}{scale_{i_0 \cdots i_{D-1}}}\right\rfloor
- offset_{i_0 \cdots i_{D-1}}, qmin, qmax\right)\\
\text{where} \quad \forall_{0 \leq d < D} \quad i_d = \left\lfloor \frac{j_d}{B_d} \right\rfloor
Args:
shape (tuple): Shape of the quantization parameters
bitwidth (int): Quantization bitwidth
symmetric (bool): If True, performs symmetric quantization;
otherwise, performs asymmetric quantization
encoding_analyzer (EncodingAnalyzer, optional): Encoding analyzer for calibrating quantization encodings
(default: absolute min-max encoding analyzer)
block_size (Tuple[int, ...], optional): Block size
:ivar Tensor min: :math:`\theta_{min}` from which scale and offset will be derived.
:ivar Tensor max: :math:`\theta_{max}` from which scale and offset will be derived.
.. note::
:class:`QuantizeDequantize` cannot run :meth:`forward` until :attr:`min` and :attr:`max` are properly initialized,
which can be done based on input statistics using :meth:`compute_encodings` or
by manually assigning a new value to :attr:`min` and :attr:`max`.
See the examples below.
Examples:
>>> import aimet_torch.v2.quantization as Q
>>> input = torch.randn(5, 10)
>>> qdq = Q.affine.QuantizeDequantize(shape=(5, 2), bitwidth=8, symmetric=False, block_size=(1, 5))
>>> qdq.is_initialized()
False
>>> with qdq.compute_encodings():
... _ = qdq(input)
...
>>> qdq.is_initialized()
True
>>> qdq(input)
DequantizedTensor([[-0.2771, 0.3038, 1.0819, 0.9700, 0.9487, -0.1307,
-1.7894, -0.1709, -0.2212, 0.7741],
[-1.0295, -1.2265, -1.0295, 1.0564, 0.6177, -1.0386,
-0.0176, -2.6054, 1.8836, -0.1232],
[-0.8229, 0.5540, 0.3992, -0.2363, 1.2546, -1.0036,
0.2355, 0.1741, 1.6079, 0.6247],
[-1.0115, 1.2458, 0.9157, -1.4694, -0.0639, -0.2568,
0.0680, 1.6695, 0.7932, -0.1889],
[ 0.0158, 0.5695, 0.5220, 0.1977, -1.4475, -0.0424,
-1.1128, -0.8796, -0.1060, 1.5897]],
grad_fn=<AliasBackward0>)
>>> import aimet_torch.v2.quantization as Q
>>> input = torch.randn(5, 10)
>>> qdq = Q.affine.QuantizeDequantize(shape=(5, 2), bitwidth=8, symmetric=False, block_size=(1, 5))
>>> qdq.is_initialized()
False
>>> qdq.min = torch.nn.Parameter(-torch.ones_like(qdq.min))
>>> qdq.max = torch.nn.Parameter(torch.ones_like(qdq.max))
>>> qdq.is_initialized()
True
>>> qdq(input)
DequantizedTensor([[-0.6196, -0.9961, 0.0549, -0.6431, 1.0039, -0.8706,
1.0039, 0.4706, -0.2353, 0.8078],
[ 0.3451, -0.1176, -0.9961, -0.4549, -0.0549, -0.0471,
-0.5255, -0.2353, 1.0039, -0.9961],
[-0.4157, 0.0784, 0.5333, 0.1647, -0.9961, -0.9961,
-0.2118, -0.2196, 0.9176, 0.9490],
[ 1.0039, -0.7765, 0.4784, -0.8706, 1.0039, 0.6039,
-0.4157, -0.2118, -0.9961, 0.3137],
[ 1.0039, 0.3216, -0.2353, -0.7765, -0.9961, 0.8000,
1.0039, 0.4157, 0.4392, 0.4863]],
grad_fn=<AliasBackward0>)
"""
# NOTE: Deepspeed has a bug where it will inadvertently patch the __init__ method permanently
# unless each leaf class explicitly defines its own __init__ separately.
# As a temporary workaround, we define __init__ to avoid triggering this bug.
# pylint: disable=useless-super-delegation
def __init__(self, shape, *args, **kwargs):
super().__init__(shape, *args, **kwargs)
def forward(self, input: torch.Tensor) -> DequantizedTensor:
"""Quantizes and dequantizes the input tensor
Args:
input (torch.Tensor): Input to quantize and dequantize
Returns:
Quantize-dequantized output
"""
if not self.is_initialized():
raise RuntimeError(
'Failed to run QuantizeDequantize since quantization parameters are not initialized.'
' Please initialize the quantization parameters using `compute_encodings()`.'
)
encoding = self.get_encoding()
# Subclasses of torch.Tensor with a custom __torch_function__ (in our case, QuantizedTensorBase)
# are known to introduce substantial CPU overhead.
# Cast the input to a plain torch.Tensor for faster execution.
input = input.as_subclass(torch.Tensor)
output = quantize_dequantize(input,
encoding.scale,
encoding.offset,
encoding.qmin,
encoding.qmax,
block_size=self.block_size)
output = output.as_subclass(DequantizedTensor)
output.encoding = encoding
return output
class GroupedBlockQuantizeDequantize(QuantizeDequantize): # pylint: disable=too-many-ancestors
""" Class for performing Grouped Block Quantize Dequantize """
def __init__(self, shape, bitwidth: int, symmetric: bool, decompressed_bw: int,
encoding_analyzer: EncodingAnalyzer = None, block_size: Optional[Tuple[int, ...]] = None,
block_grouping: Optional[Tuple[int, ...]] = None):
"""
Grouped Block Quantize Dequantize constructor.
:param shape: Shape of the quantization parameters
:type shape: tuple
:param bitwidth: Quantization bitwidth
:type bitwidth: int
:param symmetric: If True, performs symmetric quantization;
otherwise, performs asymmetric quantization
:type symmetric: bool
:param decompressed_bw: Bitwidth used for decompression
:type decompressed_bw: int
:param encoding_analyzer: Encoding analyzer for calibrating quantization encodings
(default: absolute min-max encoding analyzer)
:type encoding_analyzer: EncodingAnalyzer, optional
:param block_size: Block size per dimension.
:type block_size: Tuple
:param block_grouping: Block grouping per dimension. If provided, every set of block_group scales will be
grouped together, and the maximum scale for all blocks in the group will be used to find
the scale in the decompressed_grid to be shared by all blocks in the group.
If no block_grouping is provided, default behavior uses a block group of 1 for all dims,
equivalent to Blockwise Quantization.
A value of -1 for a block group for a dimension is equivalent to grouping all blocks in
the dimension in one group. This is also equivalent to a block group value equal to the
number of blocks for that dimension.
:type block_grouping: Tuple
"""
super().__init__(shape, bitwidth, symmetric, encoding_analyzer, block_size)
self.decompressed_bw = decompressed_bw
self.block_grouping = block_grouping
if self.block_grouping is None:
# Default to BQ behavior with 1 for all block grouping dims if not provided
self.block_grouping = tuple(1 for _ in self.shape)
if block_grouping is not None:
if len(block_grouping) != len(shape):
raise RuntimeError(f'Length of block grouping {block_grouping} must equal length of shape {shape}.')
for idx, block_group in enumerate(block_grouping):
if block_group != -1 and shape[idx] % block_group != 0:
raise RuntimeError(f'Quantizer shape dimensions must divide evenly with corresponding block '
f'grouping values for shapes {shape} and block grouping {block_grouping}.')
if self.decompressed_bw < self.bitwidth:
raise RuntimeError(f'Decompressed bitwidth {decompressed_bw} cannot be smaller than self.bitwidth '
f'{bitwidth}')
if not symmetric:
raise RuntimeError('GroupedBlockQuantizeDequantize only supports symmetric quantization.')
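# Illustrative construction (hypothetical shapes): a quantizer of shape (8, 4) holds 4 block scales
# for each of 8 output channels; block_grouping=(1, 4) (or equivalently (1, -1)) groups the 4 block
# scales of each channel, so one shared per-channel scale is derived from their maximum.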
def get_scale(self, dtype=None) -> torch.Tensor:
"""
Compute quantization scale to be used for forward pass.
Overrides QuantizeDequantize self.get_scale() to apply the grouped block algorithm for calculating modified
scales.
:param dtype: dtype of the computed scale. Defaults to torch.float32.
:return: Updated scale
"""
orig_scale = super().get_scale(dtype)
orig_scale_shape = orig_scale.shape
reshaped_scale = orig_scale.view(self.get_expanded_scale_shape())
max_scale = torch.amax(reshaped_scale, list(range(1, len(orig_scale_shape) * 2, 2)), keepdim=True)
per_channel_scale = max_scale / 2 ** (self.decompressed_bw - self.bitwidth)
updated_scale = quantize_dequantize(reshaped_scale,
scale=per_channel_scale,
offset=torch.zeros_like(per_channel_scale),
qmin=1,
qmax=2 ** (self.decompressed_bw - self.bitwidth))
return updated_scale.view(orig_scale_shape)
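# Sketch of the computation above (assuming decompressed_bw=16 and bitwidth=4):
#   per_channel_scale = max(block scales in each group) / 2 ** (16 - 4)
#   each block scale is then snapped via quantize_dequantize to an integer multiple of
#   per_channel_scale in [1, 2 ** 12], i.e. a small integer times a shared per-channel scale.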
def get_expanded_scale_shape(self) -> Tuple[int, ...]:
"""
Get expanded scale shape which breaks each scale dimension into a pair of dimensions with sizes
(original_shape / block_grouping, block_grouping).
:return: Expanded scale shape
"""
expanded_shape = []
for idx, block_group in enumerate(self.block_grouping):
# Block group of -1 is equivalent to grouping all blocks together
if block_group == -1:
expanded_shape.append(1)
expanded_shape.append(self.shape[idx])
else:
expanded_shape.append(self.shape[idx] // block_group)
expanded_shape.append(block_group)
return expanded_shape
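# e.g. (hypothetical values) shape=(8, 4) with block_grouping=(1, -1) expands to [8, 1, 1, 4]:
# dim 0 stays ungrouped (8 groups of size 1) while all 4 blocks along dim 1 fall into a single group.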
def get_per_channel_scale(self, dtype=None) -> torch.Tensor:
"""
Get per channel scale.
:return: Per channel scale
"""
orig_scale = super().get_scale(dtype)
orig_scale_shape = orig_scale.shape
reshaped_scale = orig_scale.view(self.get_expanded_scale_shape())
max_scale = torch.amax(reshaped_scale, list(range(1, len(orig_scale_shape) * 2, 2)), keepdim=True)
per_channel_scale = max_scale / 2 ** (self.decompressed_bw - self.bitwidth)
return per_channel_scale
def get_per_block_integer_scale(self) -> torch.Tensor:
"""
Get per block integer scale.
:return: Per block integer scale
"""
per_channel_scale = self.get_per_channel_scale()
expanded_scale = self.get_scale().view(self.get_expanded_scale_shape())
integer_scale = torch.round(expanded_scale / per_channel_scale).int().view(self.get_scale().shape)
return integer_scale
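# Note (following from get_scale above): the per-block integer scales returned here lie in
# [1, 2 ** (decompressed_bw - bitwidth)], e.g. [1, 4096] for 16-bit decompression of 4-bit blocks.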