FloatQuantizeDequantize
- class aimet_torch.quantization.float.FloatQuantizeDequantize(exponent_bits=None, mantissa_bits=None, finite=None, unsigned_zero=None, dtype=None, encoding_analyzer=None)
 Simulates quantization by fake-casting the input.
If dtype is provided, this is equivalent to
\[out = x.to(dtype).to(x.dtype)\]
If the exponent and mantissa bits are provided, this is equivalent to
\[out = \left\lceil\frac{x_c}{scale}\right\rfloor * scale\]
where
\[\begin{split}x_c &= clamp(x, -max, max) \\ bias &= 2^{exponent} - \log_2(max) + \log_2(2 - 2^{-mantissa}) - 1 \\ scale &= 2 ^ {\left\lfloor \log_2 |x_c| + bias \right\rfloor - mantissa - bias}\end{split}\]
The IEEE standard computes the maximum representable value by
\[max = (2 - 2^{-mantissa}) * 2^{\left\lfloor 0.5 * exponent\_max \right\rfloor}\]
where
\[exponent\_max = 2^{exponent} - 1\]
- Parameters:
 exponent_bits (int) – Number of exponent bits to simulate
mantissa_bits (int) – Number of mantissa bits to simulate
dtype (torch.dtype) – torch.dtype to simulate. This argument is mutually exclusive with exponent_bits and mantissa_bits.
encoding_analyzer (EncodingAnalyzer) – If specified, the maximum value to represent will be determined dynamically based on the input statistics for finer precision.
Examples
>>> import torch
>>> import aimet_torch.v2.quantization as Q
>>> input = torch.tensor([[ 1.8998, -0.0947],[-1.0891, -0.1727]])
>>> qdq = Q.float.FloatQuantizeDequantize(mantissa_bits=7, exponent_bits=8)
>>> # Unlike AffineQuantizer, FloatQuantizer is initialized without calling compute_encodings()
>>> qdq.is_initialized()
True
>>> qdq.is_bfloat16()
True
>>> qdq.bitwidth
16
>>> qdq(input)
tensor([[ 1.8984, -0.0947],
        [-1.0859, -0.1729]])

>>> from aimet_torch.v2.quantization.encoding_analyzer import MinMaxEncodingAnalyzer
>>> encoding_analyzer = MinMaxEncodingAnalyzer(shape=[])
>>> qdq = Q.float.FloatQuantizeDequantize(dtype=torch.float16, encoding_analyzer=encoding_analyzer)
>>> qdq.is_float16()
True
>>> qdq.bitwidth
16
>>> qdq(input)
tensor([[ 1.8994, -0.0947],
        [-1.0889, -0.1727]])
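The rounding formula in the class description can also be traced by hand. The following is a minimal pure-Python sketch of that formula for a single scalar, not the library's implementation. The helper names `ieee_max`, `fake_cast`, and `to_float16_and_back` are hypothetical, the zero pass-through is an assumption (the formula takes a logarithm of the input), and subnormals, infinities, NaN, and the `finite`, `unsigned_zero`, and `encoding_analyzer` options are not modeled. The float16 round trip uses the standard-library `struct` module as a stand-in for `x.to(dtype).to(x.dtype)`.

```python
import math
import struct


def ieee_max(exponent_bits: int, mantissa_bits: int) -> float:
    """Maximum representable value, following the max formula above."""
    exponent_max = 2 ** exponent_bits - 1
    # floor(0.5 * exponent_max) equals exponent_max // 2 for integers.
    return (2 - 2.0 ** -mantissa_bits) * 2.0 ** (exponent_max // 2)


def fake_cast(x: float, exponent_bits: int, mantissa_bits: int) -> float:
    """Quantize-dequantize one Python float (hypothetical helper)."""
    if x == 0.0:
        return 0.0  # assumption: log2(0) is undefined, so zero passes through

    max_value = ieee_max(exponent_bits, mantissa_bits)
    x_c = min(max(x, -max_value), max_value)  # clamp(x, -max, max)

    # With the IEEE max, the log2(2 - 2^-mantissa) terms in the bias formula
    # cancel, leaving the integer 2^exponent - floor(0.5 * exponent_max) - 1.
    exponent_max = 2 ** exponent_bits - 1
    bias = 2 ** exponent_bits - exponent_max // 2 - 1

    scale = 2.0 ** (math.floor(math.log2(abs(x_c)) + bias) - mantissa_bits - bias)
    # Python's round() rounds halves to even, like an IEEE cast.
    return round(x_c / scale) * scale


def to_float16_and_back(x: float) -> float:
    """Round-trip through IEEE half precision using the struct module."""
    return struct.unpack("e", struct.pack("e", x))[0]


if __name__ == "__main__":
    for value in (1.8998, -0.0947, -1.0891, -0.1727):
        bf16_like = fake_cast(value, exponent_bits=8, mantissa_bits=7)
        fp16_like = fake_cast(value, exponent_bits=5, mantissa_bits=10)
        print(value, bf16_like, fp16_like, to_float16_and_back(value))
```

With exponent_bits=8 and mantissa_bits=7, the sketch maps 1.8998 to 1.8984375, which agrees with the 1.8984 printed in the bfloat16 example above. With exponent_bits=5 and mantissa_bits=10, it agrees with the float16 round trip for the four sample values.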
- property bitwidth
 Returns bitwidth of the quantizer
- compute_encodings()
 Observe inputs and update quantization parameters based on the input statistics. While compute_encodings is enabled, the quantizer forward pass performs dynamic quantization using the batch statistics.
- property exponent_bits
 Returns exponent bits
- forward(input)
 - Parameters:
 input (Tensor) – Input to quantize and dequantize
- Returns:
 Quantize-dequantized output
- classmethod from_encodings(encodings)
 Create quantizer object from encoding object
- Return type:
 
- get_encodings()
 Return the quantizer’s encodings as an EncodingBase object
- Return type:
 Optional[FloatEncoding]
- load_state_dict(state_dict, strict=True)
 Copies parameters and buffers from state_dict into this module and its descendants. If strict is True, then the keys of state_dict must exactly match the keys returned by this module’s state_dict() function.
Warning
If assign is True, the optimizer must be created after the call to load_state_dict.
- Parameters:
 state_dict (dict) – a dict containing parameters and persistent buffers.
strict (bool, optional) – whether to strictly enforce that the keys in state_dict match the keys returned by this module’s state_dict() function. Default: True
assign (bool, optional) – whether to assign items in the state dictionary to their corresponding keys in the module instead of copying them in place into the module’s current parameters and buffers. When False, the properties of the tensors in the current module are preserved, whereas when True, the properties of the tensors in the state dict are preserved. Default: False
- Returns:
 missing_keys is a list of str containing the missing keys
unexpected_keys is a list of str containing the unexpected keys
- Return type:
 NamedTuple with missing_keys and unexpected_keys fields
Note
If a parameter or buffer is registered as None and its corresponding key exists in state_dict, load_state_dict() will raise a RuntimeError.
- property mantissa_bits
 Returns mantissa bits