FloatQuantizeDequantize

class aimet_torch.v2.quantization.float.FloatQuantizeDequantize(exponent_bits=None, mantissa_bits=None, dtype=None, encoding_analyzer=None)

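Simulates quantization by fake-casting the input: values are rounded to the precision of the target floating-point format but returned in the input's original dtype.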

If dtype is provided, this is equivalent to

\[out = \texttt{x.to(dtype).to(x.dtype)}\]
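
To make the round trip concrete, here is a minimal sketch that fake-casts plain Python floats to IEEE half precision. It uses the standard library's `struct` module as a stand-in for `x.to(torch.float16).to(x.dtype)`; the helper name `fake_cast_float16` is hypothetical and is not part of AIMET.

```python
import struct


def fake_cast_float16(x: float) -> float:
    """Round-trip a Python float through IEEE 754 half precision.

    Hypothetical helper for illustration only; it mimics
    x.to(torch.float16).to(x.dtype) for a single scalar.
    """
    # The "e" format character is IEEE 754 binary16 (float16).
    return struct.unpack("<e", struct.pack("<e", x))[0]


values = [1.8998, -0.0947, -1.0891, -0.1727]
casted = [fake_cast_float16(v) for v in values]

# The results are still ordinary Python floats, but they now carry
# only float16 precision.
print([round(v, 4) for v in casted])
# → [1.8994, -0.0947, -1.0889, -0.1727]
```

The printed values agree with the float16 output shown in the Examples section, which is what one would expect if the dtype path reduces to a plain cast and cast back.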

If the exponent and mantissa bits are provided, this is equivalent to

\[out = \left\lceil\frac{x_c}{scale}\right\rfloor \cdot scale\]

where

\[\begin{split}x_c &= \operatorname{clamp}(x, -max, max) \\ bias &= 2^{exponent} - \log_2(max) + \log_2(2 - 2^{-mantissa}) - 1 \\ scale &= 2^{\left\lfloor \log_2 |x_c| + bias \right\rfloor - mantissa - bias}\end{split}\]

The IEEE standard computes the maximum representable value by

\[max = (2 - 2^{-mantissa}) \cdot 2^{\left\lfloor 0.5 \cdot exponent\_max \right\rfloor}\]

where

\[exponent\_max = 2^{exponent} - 1\]
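
The exponent/mantissa formulas above can be sketched in plain Python for a single scalar. This is an illustrative transcription of the documented equations, not AIMET's implementation: the function names `ieee_max` and `fake_cast` are hypothetical, the zero-input shortcut is an added assumption (the formulas take a logarithm of |x_c|), and any extra handling the library may apply is not modeled.

```python
import math


def ieee_max(exponent_bits: int, mantissa_bits: int) -> float:
    """Maximum representable value from the IEEE-style formula above."""
    exponent_max = 2 ** exponent_bits - 1
    # floor(0.5 * exponent_max) written as integer floor division
    return (2 - 2.0 ** -mantissa_bits) * 2.0 ** (exponent_max // 2)


def fake_cast(x: float, exponent_bits: int, mantissa_bits: int) -> float:
    """Quantize-dequantize one float following the documented equations."""
    max_val = ieee_max(exponent_bits, mantissa_bits)
    x_c = min(max(x, -max_val), max_val)  # clamp(x, -max, max)
    if x_c == 0.0:
        # Assumption: log2(0) is undefined, so return zero unchanged.
        return 0.0
    bias = (2 ** exponent_bits - math.log2(max_val)
            + math.log2(2 - 2.0 ** -mantissa_bits) - 1)
    exponent = math.floor(math.log2(abs(x_c)) + bias)
    scale = 2.0 ** (exponent - mantissa_bits - bias)
    # Python's round() rounds halves to even, like torch.round
    return round(x_c / scale) * scale
```

For instance, `ieee_max(5, 10)` evaluates to 65504.0, the largest finite float16 value, and `fake_cast(1.8998, exponent_bits=8, mantissa_bits=7)` comes out at approximately 1.8984, in line with the bfloat16 output in the Examples section. Because `bias` is computed with floating-point logarithms, results may differ from the exact grid value by a few units of rounding error.
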

Parameters
  • exponent_bits (int) – Number of exponent bits to simulate.

  • mantissa_bits (int) – Number of mantissa bits to simulate.

  • dtype (torch.dtype) – torch.dtype to simulate. This argument is mutually exclusive with exponent_bits and mantissa_bits.

  • encoding_analyzer (EncodingAnalyzer) – If specified, the maximum representable value is determined dynamically from the input statistics, for finer precision.

Examples

>>> import torch
>>> import aimet_torch.v2.quantization as Q
>>> input = torch.tensor([[ 1.8998, -0.0947],[-1.0891, -0.1727]])
>>> qdq = Q.float.FloatQuantizeDequantize(mantissa_bits=7, exponent_bits=8)
>>> # Unlike AffineQuantizer, FloatQuantizer is initialized without calling compute_encodings()
>>> qdq.is_initialized()
True
>>> qdq.is_bfloat16()
True
>>> qdq.bitwidth
16
>>> qdq(input)
tensor([[ 1.8984, -0.0947],
        [-1.0859, -0.1729]])
>>> from aimet_torch.v2.quantization.encoding_analyzer import MinMaxEncodingAnalyzer
>>> encoding_analyzer = MinMaxEncodingAnalyzer(shape=(1,))
>>> qdq = Q.float.FloatQuantizeDequantize(dtype=torch.float16, encoding_analyzer=encoding_analyzer)
>>> qdq.is_float16()
True
>>> qdq.bitwidth
16
>>> qdq(input)
tensor([[ 1.8994, -0.0947],
        [-1.0889, -0.1727]])

QuantizeDequantize

class aimet_torch.v2.quantization.float.QuantizeDequantize(exponent_bits=None, mantissa_bits=None, dtype=None, encoding_analyzer=None)

Alias of FloatQuantizeDequantize.