# -*- mode: python -*-
# =============================================================================
# @@-COPYRIGHT-START-@@
#
# Copyright (c) 2024, Qualcomm Innovation Center, Inc. All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice,
# this list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its contributors
# may be used to endorse or promote products derived from this software
# without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
# ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
# LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
# CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
# SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
# CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
# POSSIBILITY OF SUCH DAMAGE.
#
# SPDX-License-Identifier: BSD-3-Clause
#
# @@-COPYRIGHT-END-@@
# =============================================================================
# pylint: disable=redefined-builtin
""" Float quantizers """
import contextlib
import functools
from typing import Dict, List, Optional
import math
import torch
from aimet_torch.v2.quantization.encoding_analyzer import EncodingAnalyzer, _flag_extreme_min_max
from aimet_torch.v2.quantization.base import QuantizerBase
from aimet_torch.v2.quantization.float import FloatEncoding
from aimet_torch.v2.quantization.tensor import DequantizedTensor
from aimet_torch.v2.utils import StatisticsNotFoundError, patch_attr
from aimet_torch.fp_quantization import fake_cast_to_ieee_float
from ._finfo import _finfo, _torch_dtype_to_finfo
__all__ = ['QuantizeDequantize', 'FloatQuantizeDequantize']
def _ieee_float_max_representable_value(exponent_bits, mantissa_bits):
exponent_max = 2 ** exponent_bits - 1
exponent_bias = exponent_max // 2
return (2 - 2**-mantissa_bits) * 2 ** (exponent_max - exponent_bias - 1)
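# Sanity check (for illustration): with exponent_bits=5, mantissa_bits=10 (IEEE
# float16), exponent_max = 31, exponent_bias = 15, and the result is
# (2 - 2**-10) * 2**15 == 65504.0, matching torch.finfo(torch.float16).max.
# With exponent_bits=8, mantissa_bits=7 (bfloat16) it evaluates to roughly
# 3.39e38, matching torch.finfo(torch.bfloat16).max.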
class FloatQuantizeDequantize(QuantizerBase): # pylint: disable=abstract-method
r"""
Simulates quantization by fake-casting the input.
If dtype is provided, this is equivalent to
.. math::
out = x.to(dtype).to(x.dtype) \\
If the exponent and mantissa bits are provided, this is equivalent to
.. math::
out = \left\lceil\frac{x_c}{scale}\right\rfloor * scale
where
.. math::
x_c &= clamp(x, -max, max) \\
bias &= 2^{exponent} - \log_2(max) + \log_2(2 - 2^{-mantissa}) - 1 \\
scale &= 2 ^ {\left\lfloor \log_2 |x_c| + bias \right\rfloor - mantissa - bias} \\
The IEEE standard computes the maximum representable value by
.. math::
max = (2 - 2^{-mantissa}) * 2^{(\left\lfloor 0.5 * exponent\_max \right\rfloor)} \\
where
.. math::
exponent\_max = 2^{exponent} - 1 \\
Args:
exponent_bits (int): Number of exponent bits to simulate
mantissa_bits (int): Number of mantissa bits to simulate
dtype (torch.dtype): torch.dtype to simulate. This argument is mutually exclusive with exponent_bits and mantissa_bits.
encoding_analyzer (EncodingAnalyzer): If specified, the maximum value to represent will be determined dynamically based on the input statistics for finer precision.
Examples:
>>> import aimet_torch.v2.quantization as Q
>>> input = torch.tensor([[ 1.8998, -0.0947],[-1.0891, -0.1727]])
>>> qdq = Q.float.FloatQuantizeDequantize(mantissa_bits=7, exponent_bits=8)
>>> # Unlike AffineQuantizer, FloatQuantizer is initialized without calling compute_encodings()
>>> qdq.is_initialized()
True
>>> qdq.is_bfloat16()
True
>>> qdq.bitwidth
16
>>> qdq(input)
tensor([[ 1.8984, -0.0947], [-1.0859, -0.1729]])
>>> from aimet_torch.v2.quantization.encoding_analyzer import MinMaxEncodingAnalyzer
>>> encoding_analyzer = MinMaxEncodingAnalyzer(shape=[])
>>> qdq = Q.float.FloatQuantizeDequantize(dtype=torch.float16, encoding_analyzer=encoding_analyzer)
>>> qdq.is_float16()
True
>>> qdq.bitwidth
16
>>> qdq(input)
tensor([[ 1.8994, -0.0947], [-1.0889, -0.1727]])
"""
maxval: Optional[torch.Tensor]
def __init__(self,
exponent_bits: Optional[int] = None,
mantissa_bits: Optional[int] = None,
finite: Optional[bool] = None,
unsigned_zero: Optional[bool] = None,
dtype: Optional[torch.dtype] = None,
encoding_analyzer: Optional[EncodingAnalyzer] = None):
super().__init__()
if dtype is None:
if exponent_bits is None or mantissa_bits is None:
raise ValueError('Neither "dtype" nor "exponent/mantissa_bits" was specified.')
if finite is None:
self.finite = False
if unsigned_zero is None:
self.unsigned_zero = False
if dtype is not None:
if exponent_bits is not None or mantissa_bits is not None or \
finite is not None or unsigned_zero is not None:
raise ValueError(
'Argument "dtype" is mutually exclusive with "exponent/mantissa_bits/finite/unsigned_zero".')
exponent_bits, mantissa_bits, finite, unsigned_zero = _finfo.from_torch_dtype(dtype)
self._finfo = _finfo(exponent_bits, mantissa_bits, finite, unsigned_zero)
self.encoding_analyzer = encoding_analyzer
if self.encoding_analyzer:
shape = self.encoding_analyzer.observer.shape
maxval = _ieee_float_max_representable_value(exponent_bits, mantissa_bits)
self.register_buffer('maxval', torch.full(shape, maxval))
else:
self.register_buffer('maxval', None)
self._assert_supported_dtype()
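# Construction sketch (illustrative); the two modes are mutually exclusive:
#
#   FloatQuantizeDequantize(exponent_bits=8, mantissa_bits=7)  # bfloat16-like simulation
#   FloatQuantizeDequantize(dtype=torch.float16)               # simulate a built-in dtype
#
# Passing `dtype` together with exponent/mantissa bits (or finite/unsigned_zero)
# raises ValueError.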
def _assert_supported_dtype(self):
if self._finfo.finite or self._finfo.unsigned_zero:
if self._finfo.to_torch_dtype() is None:
torch_special_builtin_dtypes = [
dtype for dtype in _torch_dtype_to_finfo
if dtype not in (torch.float16, torch.bfloat16)
]
msg = " ".join([
"finite/unsigned_zero floating point has limited support.",
f"Expected PyTorch built-in data types, such as {torch_special_builtin_dtypes};",
f"got '{self._finfo.to_str()}'",
])
raise RuntimeError(msg)
if self.maxval is not None:
msg = " ".join([
"finite/unsigned_zero floating point has limited support.",
"Expected 'maxval' to be None",
])
raise RuntimeError(msg)
@property
def exponent_bits(self):
""" Returns exponent bits """
return self._finfo.exponent_bits
@exponent_bits.setter
def exponent_bits(self, exponent_bits: int):
_, mantissa_bits, finite, unsigned_zero = self._finfo
self._finfo = _finfo(exponent_bits, mantissa_bits, finite, unsigned_zero)
@property
def mantissa_bits(self):
""" Returns mantissa bits """
return self._finfo.mantissa_bits
@mantissa_bits.setter
def mantissa_bits(self, mantissa_bits: int):
exponent_bits, _, finite, unsigned_zero = self._finfo
self._finfo = _finfo(exponent_bits, mantissa_bits, finite, unsigned_zero)
def load_state_dict(self, state_dict, strict: bool = True):
if 'maxval' in state_dict:
if self.maxval is None:
del self.maxval
self.register_buffer('maxval', state_dict['maxval'])
elif self.maxval is not None:
del self.maxval
self.register_buffer('maxval', None)
ret = super().load_state_dict(state_dict, strict)
return ret
@property
def bitwidth(self):
"""
Returns bitwidth of the quantizer
"""
return self.exponent_bits + self.mantissa_bits + 1
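# For example, simulating IEEE float16 (exponent_bits=5, mantissa_bits=10) or
# bfloat16 (exponent_bits=8, mantissa_bits=7) both report a bitwidth of 16
# (5 + 10 + 1 and 8 + 7 + 1, respectively, counting the sign bit).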
def is_float16(self):
"""
Returns true if current configuration simulates IEEE float16
"""
return self._finfo.is_float16()
def is_bfloat16(self):
"""
Returns true if current configuration simulates bfloat16
"""
return self._finfo.is_bfloat16()
def get_legacy_encodings(self) -> Optional[List[Dict]]:
"""
:meta private:
"""
return [{'bitwidth': self.bitwidth, 'dtype': 'float'}]
def set_legacy_encodings(self, encodings: List[Dict]):
"""
:meta private:
Set encodings represented in the same format as the output of get_legacy_encodings as below:
[
{'bitwidth': int, 'dtype': str},
...
]
"""
if encodings[0]['bitwidth'] != 16:
raise RuntimeError(f"{self.__class__} can only import 16-bit legacy encodings.")
self.exponent_bits = 5
self.mantissa_bits = 10
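# Example (illustrative): importing a legacy encoding such as
# [{'bitwidth': 16, 'dtype': 'float'}] reconfigures the quantizer to simulate
# IEEE float16 (5 exponent bits, 10 mantissa bits); any other bitwidth raises
# a RuntimeError.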
def get_encodings(self) -> Optional[FloatEncoding]:
if self.is_initialized():
return FloatEncoding(self._finfo.mantissa_bits,
self._finfo.exponent_bits,
self._finfo.finite,
self._finfo.unsigned_zero,
self.maxval)
return None
@classmethod
def from_encodings(cls, encodings: FloatEncoding) -> "FloatQuantizeDequantize":
if not isinstance(encodings, FloatEncoding):
raise TypeError(f"Expected {FloatEncoding}; got {type(encodings)}")
qtzr = cls(exponent_bits=encodings.exponent_bits,
mantissa_bits=encodings.mantissa_bits)
if encodings.maxval is not None:
qtzr.maxval.copy_(encodings.maxval)
return qtzr
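# Round-trip sketch (illustrative), assuming `qdq` is a FloatQuantizeDequantize
# constructed with explicit exponent/mantissa bits and no encoding analyzer:
#
#   encoding = qdq.get_encodings()
#   qdq_copy = FloatQuantizeDequantize.from_encodings(encoding)
#   assert qdq_copy.exponent_bits == encoding.exponent_bits
#   assert qdq_copy.mantissa_bits == encoding.mantissa_bits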
@contextlib.contextmanager
def compute_encodings(self):
"""
Observe inputs and update quantization parameters based on the input statistics.
While ``compute_encodings`` is enabled, the quantizer forward pass performs
dynamic quantization using the batch statistics.
"""
if not self.encoding_analyzer or not self._allow_overwrite:
yield
return
original_forward = self.forward
@functools.wraps(original_forward)
def forward_wrapper(input):
input = input.as_subclass(torch.Tensor)
batch_statistics = self.encoding_analyzer.update_stats(input)
num_steps = math.pow(2, self.bitwidth) - 1
dynamic_min, dynamic_max =\
self.encoding_analyzer.compute_encodings_from_stats(batch_statistics,
num_steps,
is_symmetric=False)
dynamic_absmax = torch.maximum(dynamic_min.abs(), dynamic_max.abs())
dynamic_absmax = dynamic_absmax.to(dtype=self.maxval.dtype,
device=self.maxval.device).expand_as(self.maxval)
with patch_attr(self, 'maxval', dynamic_absmax):
return original_forward(input)
self.encoding_analyzer.reset_stats()
try:
with patch_attr(self, 'forward', forward_wrapper):
yield
except: # pylint: disable=try-except-raise
raise
try:
num_steps = math.pow(2, self.bitwidth) - 1
min, max = self.encoding_analyzer.compute_encodings(num_steps,
is_symmetric=False)
_flag_extreme_min_max(min, max)
except StatisticsNotFoundError:
return
if min is None or max is None:
return
absmax = torch.maximum(min.abs(), max.abs()).expand_as(self.maxval)
with torch.no_grad():
self.maxval.copy_(absmax)
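# Calibration sketch (illustrative): when an encoding analyzer is attached,
# `maxval` is calibrated from representative data inside the
# `compute_encodings()` context; `calibration_loader` and `test_input` below
# are hypothetical placeholders.
#
#   from aimet_torch.v2.quantization.encoding_analyzer import MinMaxEncodingAnalyzer
#   qdq = FloatQuantizeDequantize(exponent_bits=5, mantissa_bits=10,
#                                 encoding_analyzer=MinMaxEncodingAnalyzer(shape=[]))
#   with qdq.compute_encodings():
#       for batch in calibration_loader:
#           _ = qdq(batch)
#   out = qdq(test_input)  # uses the calibrated maxval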
def forward(self, input: torch.Tensor):
"""
:param input: Input to quantize and dequantize
:return: Quantize-dequantized output
"""
self._assert_supported_dtype()
if not self.is_initialized():
raise RuntimeError(
'Failed to run FloatQuantizeDequantize since quantization parameters are not initialized.'
' Please initialize the quantization parameters using `compute_encodings()`.'
)
encoding = self.get_encodings()
assert encoding is not None
maxval = encoding.maxval
exponent_bits = encoding.exponent_bits
mantissa_bits = encoding.mantissa_bits
target_torch_dtype = self._finfo.to_torch_dtype()
if maxval is None and target_torch_dtype is not None:
# Fast path: simulate by casting to the target dtype and back
orig_dtype = input.dtype
output = input.to(target_torch_dtype).to(orig_dtype)
else:
if maxval is None:
maxval = _ieee_float_max_representable_value(exponent_bits, mantissa_bits)
# Subclasses of torch.Tensor with a custom __torch_function__ (in our case, QuantizedTensorBase)
# are known to introduce substantial CPU overhead.
# Cast the input to a plain torch.Tensor for faster execution.
output = fake_cast_to_ieee_float(input.as_subclass(torch.Tensor),
maxval,
exponent_bits,
mantissa_bits)
output = output.as_subclass(DequantizedTensor)
output.encoding = encoding
return output
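# Note (illustrative): the returned DequantizedTensor carries its FloatEncoding,
# so callers can inspect e.g. `qdq(x).encoding.exponent_bits`, `.mantissa_bits`,
# or `.maxval` alongside the fake-quantized values.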
def extra_repr(self):
"""
:meta private:
"""
if self.maxval is None:
torch_dtype = self._finfo.to_torch_dtype()
if torch_dtype is not None:
return f"dtype={torch_dtype}"
exponent_bits, mantissa_bits, finite, unsigned_zero = self._finfo
return " ".join([
f"exponent_bits={exponent_bits}",
f"mantissa_bits={mantissa_bits}",
f"finite={finite}",
f"unsigned_zero={unsigned_zero}",
])
class QuantizeDequantize(FloatQuantizeDequantize):
r"""
Alias of FloatQuantizeDequantize
"""