FloatQuantizeDequantize
- class aimet_torch.v2.quantization.float.FloatQuantizeDequantize(exponent_bits=None, mantissa_bits=None, dtype=None, encoding_analyzer=None)
Simulates quantization by fake-casting the input.
If dtype is provided, this is equivalent to
\[out = x.to(dtype).to(x.dtype)\]
If the exponent and mantissa bits are provided, this is equivalent to
\[out = \left\lceil\frac{x_c}{scale}\right\rfloor * scale\]
where
\[\begin{split}x_c &= clamp(x, -max, max) \\ bias &= 2^{exponent} - \log_2(max) + \log_2(2 - 2^{-mantissa}) - 1 \\ scale &= 2 ^ {\left\lfloor \log_2 |x_c| + bias \right\rfloor - mantissa - bias}\end{split}\]
The IEEE standard computes the maximum representable value by
\[max = (2 - 2^{-mantissa}) * 2^{\left\lfloor 0.5 * exponent\_max \right\rfloor}\]
where
\[exponent\_max = 2^{exponent} - 1\]
- Parameters:
exponent_bits (int) – Number of exponent bits to simulate
mantissa_bits (int) – Number of mantissa bits to simulate
dtype (torch.dtype) – torch.dtype to simulate. This argument is mutually exclusive with exponent_bits and mantissa_bits.
encoding_analyzer (EncodingAnalyzer) – If specified, the maximum value to represent will be determined dynamically based on the input statistics for finer precision.
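The exponent/mantissa formulas above can be traced by hand with a small scalar sketch in plain Python. This is not the library's implementation: the helper names `ieee_max` and `fake_cast` and the `maxval` argument are hypothetical, the zero special case is an assumption added because \(\log_2 0\) is undefined, and subnormal values are not treated specially. It only evaluates the equations as written in this section.

```python
import math


def ieee_max(exponent_bits: int, mantissa_bits: int) -> float:
    """Maximum representable value, per the 'max' formula in this section."""
    exponent_max = 2 ** exponent_bits - 1
    return (2 - 2.0 ** -mantissa_bits) * 2.0 ** math.floor(0.5 * exponent_max)


def fake_cast(x: float, exponent_bits: int, mantissa_bits: int, maxval=None) -> float:
    """Evaluate out = round(x_c / scale) * scale for a single Python float.

    maxval is a hypothetical override standing in for a dynamically
    determined maximum; by default the IEEE-style maximum is used.
    """
    if maxval is None:
        maxval = ieee_max(exponent_bits, mantissa_bits)
    x_c = min(max(x, -maxval), maxval)  # clamp(x, -max, max)
    if x_c == 0.0:
        return 0.0  # assumption: zero maps to zero, since log2(0) is undefined
    bias = (2 ** exponent_bits - math.log2(maxval)
            + math.log2(2 - 2.0 ** -mantissa_bits) - 1)
    scale = 2.0 ** (math.floor(math.log2(abs(x_c)) + bias) - mantissa_bits - bias)
    # Python's round() rounds halves to even, i.e. round-to-nearest-even
    return round(x_c / scale) * scale


values = [1.8998, -0.0947, -1.0891, -0.1727]
print([round(fake_cast(v, exponent_bits=8, mantissa_bits=7), 4) for v in values])
```

With `exponent_bits=8` and `mantissa_bits=7` (the bfloat16 layout), the sketch rounds 1.8998 to approximately 1.8984375, up to floating-point error in the `bias` term. Passing a smaller `maxval` shifts `bias` and therefore the range of magnitudes that get the finest scale, which is the effect the `encoding_analyzer` parameter description refers to.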
Examples
>>> import torch
>>> import aimet_torch.v2.quantization as Q
>>> input = torch.tensor([[ 1.8998, -0.0947],[-1.0891, -0.1727]])
>>> qdq = Q.float.FloatQuantizeDequantize(mantissa_bits=7, exponent_bits=8)
>>> # Unlike AffineQuantizer, FloatQuantizer is initialized without calling compute_encodings()
>>> qdq.is_initialized()
True
>>> qdq.is_bfloat16()
True
>>> qdq.bitwidth
16
>>> qdq(input)
tensor([[ 1.8984, -0.0947],
        [-1.0859, -0.1729]])

>>> from aimet_torch.v2.quantization.encoding_analyzer import MinMaxEncodingAnalyzer
>>> encoding_analyzer = MinMaxEncodingAnalyzer(shape=[])
>>> qdq = Q.float.FloatQuantizeDequantize(dtype=torch.float16, encoding_analyzer=encoding_analyzer)
>>> qdq.is_float16()
True
>>> qdq.bitwidth
16
>>> qdq(input)
tensor([[ 1.8994, -0.0947],
        [-1.0889, -0.1727]])
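As a rough cross-check of the dtype path (`out = x.to(dtype).to(x.dtype)`), the same cast-and-cast-back idea can be imitated for float16 with only the Python standard library, whose `struct` module supports the IEEE binary16 format code `e`. This is an illustration of the fake-cast concept, not how the class is implemented; the helper name `to_float16_and_back` is hypothetical.

```python
import struct


def to_float16_and_back(x: float) -> float:
    """Round-trip a Python float through IEEE binary16 (struct format 'e')."""
    return struct.unpack("<e", struct.pack("<e", x))[0]


values = [1.8998, -0.0947, -1.0891, -0.1727]
print([round(to_float16_and_back(v), 4) for v in values])
```

Note that `struct.pack` raises `OverflowError` for values outside the float16 range instead of clamping, so this sketch only covers in-range inputs.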