Source code for aimet_torch.v2.quantization.float.quantizer

# -*- mode: python -*-
# =============================================================================
#  @@-COPYRIGHT-START-@@
#
#  Copyright (c) 2024, Qualcomm Innovation Center, Inc. All rights reserved.
#
#  Redistribution and use in source and binary forms, with or without
#  modification, are permitted provided that the following conditions are met:
#
#  1. Redistributions of source code must retain the above copyright notice,
#     this list of conditions and the following disclaimer.
#
#  2. Redistributions in binary form must reproduce the above copyright notice,
#     this list of conditions and the following disclaimer in the documentation
#     and/or other materials provided with the distribution.
#
#  3. Neither the name of the copyright holder nor the names of its contributors
#     may be used to endorse or promote products derived from this software
#     without specific prior written permission.
#
#  THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
#  AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
#  IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
#  ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
#  LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
#  CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
#  SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
#  INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
#  CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
#  ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
#  POSSIBILITY OF SUCH DAMAGE.
#
#  SPDX-License-Identifier: BSD-3-Clause
#
#  @@-COPYRIGHT-END-@@
# =============================================================================
# pylint: disable=redefined-builtin
""" Float quantizers """

import contextlib
import functools
from typing import Dict, List, Optional
import math

import torch
from aimet_torch.v2.quantization.encoding_analyzer import EncodingAnalyzer, _flag_extreme_min_max
from aimet_torch.v2.quantization.base import QuantizerBase
from aimet_torch.v2.quantization.float import FloatEncoding
from aimet_torch.v2.quantization.tensor import DequantizedTensor
from aimet_torch.v2.utils import StatisticsNotFoundError, patch_attr
from aimet_torch.fp_quantization import fake_cast_to_ieee_float
from ._finfo import _finfo, _torch_dtype_to_finfo


__all__ = ['QuantizeDequantize', 'FloatQuantizeDequantize']


def _ieee_float_max_representable_value(exponent_bits, mantissa_bits):
    exponent_max = 2 ** exponent_bits - 1
    exponent_bias = exponent_max // 2
    return (2 - 2**-mantissa_bits) * 2 ** (exponent_max - exponent_bias - 1)

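# Illustrative check (not part of the original source, shown for clarity): this formula
# reproduces the torch.finfo maxima of the built-in IEEE-style formats, e.g.
#
#     _ieee_float_max_representable_value(exponent_bits=5, mantissa_bits=10)   # 65504.0  == torch.finfo(torch.float16).max
#     _ieee_float_max_representable_value(exponent_bits=8, mantissa_bits=7)    # ~3.39e38 == torch.finfo(torch.bfloat16).max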

class FloatQuantizeDequantize(QuantizerBase): # pylint: disable=abstract-method
    r"""
    Simulates quantization by fake-casting the input

    If dtype is provided, this is equivalent to

    .. math::
        out = x.to(dtype).to(x.dtype) \\

    If the exponent and mantissa bits are provided, this is equivalent to

    .. math::
        out = \left\lceil\frac{x_c}{scale}\right\rfloor * scale

    where

    .. math::
        x_c   &= clamp(x, -max, max) \\
        bias  &= 2^{exponent} - \log_2(max) + \log_2(2 - 2^{-mantissa}) - 1 \\
        scale &= 2 ^ {\left\lfloor \log_2 |x_c| + bias \right\rfloor - mantissa - bias} \\

    The IEEE standard computes the maximum representable value by

    .. math::
        max = (2 - 2^{-mantissa}) * 2^{\left\lfloor 0.5 * exponent\_max \right\rfloor}

    where

    .. math::
        exponent\_max = 2^{exponent} - 1

    Args:
        exponent_bits (int): Number of exponent bits to simulate
        mantissa_bits (int): Number of mantissa bits to simulate
        dtype (torch.dtype): torch.dtype to simulate. This argument is mutually exclusive with
            exponent_bits and mantissa_bits.
        encoding_analyzer (EncodingAnalyzer): If specified, the maximum value to represent will be
            determined dynamically based on the input statistics for finer precision.

    Examples:

        >>> import aimet_torch.v2.quantization as Q
        >>> input = torch.tensor([[ 1.8998, -0.0947],[-1.0891, -0.1727]])
        >>> qdq = Q.float.FloatQuantizeDequantize(mantissa_bits=7, exponent_bits=8)
        >>> # Unlike AffineQuantizer, FloatQuantizer is initialized without calling compute_encodings()
        >>> qdq.is_initialized()
        True
        >>> qdq.is_bfloat16()
        True
        >>> qdq.bitwidth
        16
        >>> qdq(input)
        tensor([[ 1.8984, -0.0947],
                [-1.0859, -0.1729]])

        >>> from aimet_torch.v2.quantization.encoding_analyzer import MinMaxEncodingAnalyzer
        >>> encoding_analyzer = MinMaxEncodingAnalyzer(shape=[])
        >>> qdq = Q.float.FloatQuantizeDequantize(dtype=torch.float16, encoding_analyzer=encoding_analyzer)
        >>> qdq.is_float16()
        True
        >>> qdq.bitwidth
        16
        >>> qdq(input)
        tensor([[ 1.8994, -0.0947],
                [-1.0889, -0.1727]])
    """
    maxval: Optional[torch.Tensor]

    def __init__(self,
                 exponent_bits: Optional[int] = None,
                 mantissa_bits: Optional[int] = None,
                 finite: Optional[bool] = None,
                 unsigned_zero: Optional[bool] = None,
                 dtype: Optional[torch.dtype] = None,
                 encoding_analyzer: Optional[EncodingAnalyzer] = None):
        super().__init__()

        if dtype is None:
            if exponent_bits is None or mantissa_bits is None:
                raise ValueError('Neither "dtype" nor "exponent/mantissa_bits" was specified.')
            if finite is None:
                finite = False
            if unsigned_zero is None:
                unsigned_zero = False

        if dtype is not None:
            if exponent_bits is not None or mantissa_bits is not None or \
                    finite is not None or unsigned_zero is not None:
                raise ValueError(
                    'Argument "dtype" is mutually exclusive with "exponent/mantissa_bits/finite/unsigned_zero".')
            exponent_bits, mantissa_bits, finite, unsigned_zero = _finfo.from_torch_dtype(dtype)

        self._finfo = _finfo(exponent_bits, mantissa_bits, finite, unsigned_zero)
        self.encoding_analyzer = encoding_analyzer

        if self.encoding_analyzer:
            shape = self.encoding_analyzer.observer.shape
            maxval = _ieee_float_max_representable_value(exponent_bits, mantissa_bits)
            self.register_buffer('maxval', torch.full(shape, maxval))
        else:
            self.register_buffer('maxval', None)

        self._assert_supported_dtype()

    def _assert_supported_dtype(self):
        if self._finfo.finite or self._finfo.unsigned_zero:
            if self._finfo.to_torch_dtype() is None:
                torch_special_builtin_dtypes = [
                    dtype for dtype in _torch_dtype_to_finfo
                    if dtype not in (torch.float16, torch.bfloat16)
                ]
                msg = " ".join([
                    "finite/unsigned_zero floating point has limited support.",
                    f"Expected PyTorch built-in data types, such as {torch_special_builtin_dtypes};",
                    f"got '{self._finfo.to_str()}'",
                ])
                raise RuntimeError(msg)

            if self.maxval is not None:
                msg = " ".join([
                    "finite/unsigned_zero floating point has limited support.",
                    "Expected 'maxval' to be None",
                ])
                raise RuntimeError(msg)

    @property
    def exponent_bits(self):
        """ Returns exponent bits """
        return self._finfo.exponent_bits

    @exponent_bits.setter
    def exponent_bits(self, exponent_bits: int):
        _, mantissa_bits, finite, unsigned_zero = self._finfo
        self._finfo = _finfo(exponent_bits, mantissa_bits, finite, unsigned_zero)

    @property
    def mantissa_bits(self):
        """ Returns mantissa bits """
        return self._finfo.mantissa_bits

    @mantissa_bits.setter
    def mantissa_bits(self, mantissa_bits: int):
        exponent_bits, _, finite, unsigned_zero = self._finfo
        self._finfo = _finfo(exponent_bits, mantissa_bits, finite, unsigned_zero)
    def get_extra_state(self):
        extra_state_dict = super().get_extra_state()
        extra_state_dict['exponent_bits'] = torch.tensor(self.exponent_bits)
        extra_state_dict['mantissa_bits'] = torch.tensor(self.mantissa_bits)
        return extra_state_dict
    def set_extra_state(self, state):
        self.exponent_bits = state['exponent_bits'].item()
        self.mantissa_bits = state['mantissa_bits'].item()
        super().set_extra_state(state)
    def load_state_dict(self, state_dict, strict: bool = True):
        if 'maxval' in state_dict:
            if self.maxval is None:
                del self.maxval
                self.register_buffer('maxval', state_dict['maxval'])
        elif self.maxval is not None:
            del self.maxval
            self.register_buffer('maxval', None)

        ret = super().load_state_dict(state_dict, strict)
        return ret
    @property
    def bitwidth(self):
        """ Returns bitwidth of the quantizer """
        return self.exponent_bits + self.mantissa_bits + 1
    def is_float16(self):
        """ Returns true if current configuration simulates IEEE float16 """
        return self._finfo.is_float16()
    def is_bfloat16(self):
        """ Returns true if current configuration simulates bfloat16 """
        return self._finfo.is_bfloat16()
    def get_legacy_encodings(self) -> Optional[List[Dict]]:
        """ :meta private: """
        return [{'bitwidth': self.bitwidth, 'dtype': 'float'}]

    def set_legacy_encodings(self, encodings: List[Dict]):
        """
        :meta private:

        Set encodings represented in the same format as the output of get_legacy_encodings as below:

        [
            {'bitwidth': int, 'dtype': str},
            ...
        ]
        """
        if encodings[0]['bitwidth'] != 16:
            raise RuntimeError(f"{self.__class__} can only import 16-bit legacy encodings.")
        self.exponent_bits = 5
        self.mantissa_bits = 10
    def get_encodings(self) -> Optional[FloatEncoding]:
        if self.is_initialized():
            return FloatEncoding(self._finfo.mantissa_bits,
                                 self._finfo.exponent_bits,
                                 self._finfo.finite,
                                 self._finfo.unsigned_zero,
                                 self.maxval)
        return None
    @classmethod
    def from_encodings(cls, encodings: FloatEncoding) -> "FloatQuantizeDequantize":
        if not isinstance(encodings, FloatEncoding):
            raise TypeError(f"Expected {FloatEncoding}; got {type(encodings)}")

        qtzr = cls(exponent_bits=encodings.exponent_bits, mantissa_bits=encodings.mantissa_bits)

        if encodings.maxval is not None:
            qtzr.maxval.copy_(encodings.maxval)

        return qtzr
    @contextlib.contextmanager
    def compute_encodings(self):
        """
        Observe inputs and update quantization parameters based on the input statistics.

        While ``compute_encodings`` is enabled, the quantizer forward pass performs
        dynamic quantization using the batch statistics.
        """
        if not self.encoding_analyzer or not self._allow_overwrite:
            yield
            return

        original_forward = self.forward

        @functools.wraps(original_forward)
        def forward_wrapper(input):
            input = input.as_subclass(torch.Tensor)
            batch_statistics = self.encoding_analyzer.update_stats(input)
            num_steps = math.pow(2, self.bitwidth) - 1
            dynamic_min, dynamic_max =\
                    self.encoding_analyzer.compute_encodings_from_stats(batch_statistics,
                                                                        num_steps,
                                                                        is_symmetric=False)
            dynamic_absmax = torch.maximum(dynamic_min.abs(), dynamic_max.abs())
            dynamic_absmax = dynamic_absmax.to(dtype=self.maxval.dtype,
                                               device=self.maxval.device).expand_as(self.maxval)
            with patch_attr(self, 'maxval', dynamic_absmax):
                return original_forward(input)

        self.encoding_analyzer.reset_stats()

        try:
            with patch_attr(self, 'forward', forward_wrapper):
                yield
        except: # pylint: disable=try-except-raise
            raise

        try:
            num_steps = math.pow(2, self.bitwidth) - 1
            min, max = self.encoding_analyzer.compute_encodings(num_steps, is_symmetric=False)
            _flag_extreme_min_max(min, max)
        except StatisticsNotFoundError:
            return

        if min is None or max is None:
            return

        absmax = torch.maximum(min.abs(), max.abs()).expand_as(self.maxval)
        with torch.no_grad():
            self.maxval.copy_(absmax)
    def forward(self, input: torch.Tensor):
        """
        :param input: Input to quantize and dequantize
        :return: Quantize-dequantized output
        """
        self._assert_supported_dtype()

        if not self.is_initialized():
            raise RuntimeError(
                'Failed to run FloatQuantizeDequantize since quantization parameters are not initialized.'
                ' Please initialize the quantization parameters using `compute_encodings()`.'
            )

        encoding = self.get_encodings()
        assert encoding is not None
        maxval = encoding.maxval
        exponent_bits = encoding.exponent_bits
        mantissa_bits = encoding.mantissa_bits

        target_torch_dtype = self._finfo.to_torch_dtype()

        if maxval is None and target_torch_dtype is not None:
            # Fast forward using type casting
            orig_dtype = input.dtype
            output = input.to(target_torch_dtype).to(orig_dtype)
        else:
            if maxval is None:
                maxval = _ieee_float_max_representable_value(exponent_bits, mantissa_bits)
            # Subclasses of torch.Tensor with custom __torch_function__ (in our case, QuantizedTensorBase)
            # are known to introduce substantial CPU overhead.
            # Cast the inputs to plain torch.Tensor for faster execution.
            output = fake_cast_to_ieee_float(input.as_subclass(torch.Tensor), maxval, exponent_bits, mantissa_bits)

        output = output.as_subclass(DequantizedTensor)
        output.encoding = encoding
        return output
    def extra_repr(self):
        """ :meta private: """
        if self.maxval is None:
            torch_dtype = self._finfo.to_torch_dtype()
            if torch_dtype is not None:
                return f"dtype={torch_dtype}"

        exponent_bits, mantissa_bits, finite, unsigned_zero = self._finfo
        return " ".join([
            f"exponent_bits={exponent_bits}",
            f"mantissa_bits={mantissa_bits}",
            f"finite={finite}",
            f"unsigned_zero={unsigned_zero}",
        ])
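# Illustrative note (not part of the original source): when `maxval` is None and the simulated
# format maps onto a PyTorch built-in dtype, forward() takes the fast path above and reduces to
# the round-trip cast from the class docstring, e.g.
#
#     x = torch.randn(4)
#     qdq = FloatQuantizeDequantize(dtype=torch.float16)
#     y = qdq(x)    # same values as x.to(torch.float16).to(x.dtype), returned as a DequantizedTensor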
class QuantizeDequantize(FloatQuantizeDequantize):
    r"""
    Alias of FloatQuantizeDequantize
    """
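# Usage sketch (illustrative, not part of the original source): with an encoding analyzer
# attached, compute_encodings() observes the inputs seen inside the context and then freezes
# `maxval` to their absolute maximum, as described in the compute_encodings docstring.
#
#     from aimet_torch.v2.quantization.encoding_analyzer import MinMaxEncodingAnalyzer
#
#     qdq = FloatQuantizeDequantize(exponent_bits=5, mantissa_bits=10,
#                                   encoding_analyzer=MinMaxEncodingAnalyzer(shape=[]))
#     with qdq.compute_encodings():
#         _ = qdq(torch.randn(16, 8))     # dynamic quantization using batch statistics
#     print(qdq.maxval)                   # per-tensor absolute max observed above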