FloatQuantizeDequantize¶
- class aimet_torch.quantization.float.FloatQuantizeDequantize(exponent_bits=None, mantissa_bits=None, dtype=None, encoding_analyzer=None)[source]¶
Simulates quantization by fake-casting the input
If dtype is provided, this is equivalent to
\[out = x.to(dtype).to(x.dtype)\]
If the exponent and mantissa bits are provided, this is equivalent to
\[out = \left\lceil\frac{x_c}{scale}\right\rfloor * scale\]
where
\[\begin{split}x_c &= clamp(x, -max, max) \\ bias &= 2^{exponent} - \log_2(max) + \log_2(2 - 2^{-mantissa}) - 1 \\ scale &= 2^{\left\lfloor \log_2 |x_c| + bias \right\rfloor - mantissa - bias}\end{split}\]
The IEEE standard computes the maximum representable value by
\[max = (2 - 2^{-mantissa}) * 2^{\left\lfloor 0.5 * exponent\_max \right\rfloor}\]
where
\[exponent\_max = 2^{exponent} - 1\]
- Parameters:
exponent_bits (int) – Number of exponent bits to simulate
mantissa_bits (int) – Number of mantissa bits to simulate
dtype (torch.dtype) – torch.dtype to simulate. This argument is mutually exclusive with exponent_bits and mantissa_bits.
encoding_analyzer (EncodingAnalyzer) – If specified, the maximum value to represent will be determined dynamically based on the input statistics for finer precision.
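For example, plugging IEEE float16 (exponent_bits=5, mantissa_bits=10) into the formulas above reproduces the familiar float16 maximum:
\[exponent\_max = 2^5 - 1 = 31, \qquad max = (2 - 2^{-10}) * 2^{\left\lfloor 0.5 * 31 \right\rfloor} = (2 - 2^{-10}) \cdot 2^{15} = 65504\]
which is the largest finite value representable in float16.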
Examples
>>> import torch
>>> import aimet_torch.v2.quantization as Q
>>> input = torch.tensor([[ 1.8998, -0.0947], [-1.0891, -0.1727]])
>>> qdq = Q.float.FloatQuantizeDequantize(mantissa_bits=7, exponent_bits=8)
>>> # Unlike AffineQuantizer, FloatQuantizer is initialized without calling compute_encodings()
>>> qdq.is_initialized()
True
>>> qdq.is_bfloat16()
True
>>> qdq.bitwidth
16
>>> qdq(input)
tensor([[ 1.8984, -0.0947],
        [-1.0859, -0.1729]])
>>> from aimet_torch.v2.quantization.encoding_analyzer import MinMaxEncodingAnalyzer
>>> encoding_analyzer = MinMaxEncodingAnalyzer(shape=[])
>>> qdq = Q.float.FloatQuantizeDequantize(dtype=torch.float16, encoding_analyzer=encoding_analyzer)
>>> qdq.is_float16()
True
>>> qdq.bitwidth
16
>>> qdq(input)
tensor([[ 1.8994, -0.0947],
        [-1.0889, -0.1727]])
- property bitwidth¶
Returns bitwidth of the quantizer
- compute_encodings()[source]¶
Observe inputs and update quantization parameters based on the input statistics. While compute_encodings is enabled, the quantizer forward pass performs dynamic quantization using the batch statistics.
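A minimal usage sketch, not taken from the original examples: it assumes compute_encodings is used as a context manager and reuses the encoding_analyzer-based quantizer from the second example above.
>>> qdq = Q.float.FloatQuantizeDequantize(dtype=torch.float16,
...                                       encoding_analyzer=MinMaxEncodingAnalyzer(shape=[]))
>>> with qdq.compute_encodings():   # statistics are observed; forward passes quantize dynamically
...     _ = qdq(input)
>>> out = qdq(input)                # later calls reuse the calibrated range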
- forward(input)[source]¶
- Parameters:
input (Tensor) – Input to quantize and dequantize
- Returns:
Quantize-dequantized output
- get_encodings()[source]¶
Return the quantizer’s encodings as an EncodingBase object
- Return type:
Optional[FloatEncoding]
- load_state_dict(state_dict, strict=True)[source]¶
Copies parameters and buffers from state_dict into this module and its descendants. If strict is True, then the keys of state_dict must exactly match the keys returned by this module’s state_dict() function.
Warning
If assign is True the optimizer must be created after the call to load_state_dict.
- Parameters:
state_dict (dict) – a dict containing parameters and persistent buffers.
strict (bool, optional) – whether to strictly enforce that the keys in state_dict match the keys returned by this module’s state_dict() function. Default: True
assign (bool, optional) – whether to assign items in the state dictionary to their corresponding keys in the module instead of copying them inplace into the module’s current parameters and buffers. When False, the properties of the tensors in the current module are preserved while when True, the properties of the Tensors in the state dict are preserved. Default: False
- Returns:
missing_keys is a list of str containing the missing keys
unexpected_keys is a list of str containing the unexpected keys
- Return type:
NamedTuple with missing_keys and unexpected_keys fields
Note
If a parameter or buffer is registered as None and its corresponding key exists in state_dict, load_state_dict() will raise a RuntimeError.
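A minimal round-trip sketch, not from the original documentation; it assumes the quantizer’s state (if any) can be captured with state_dict() and restored into a freshly constructed quantizer of the same configuration.
>>> qdq = Q.float.FloatQuantizeDequantize(mantissa_bits=7, exponent_bits=8)
>>> state = qdq.state_dict()
>>> restored = Q.float.FloatQuantizeDequantize(mantissa_bits=7, exponent_bits=8)
>>> result = restored.load_state_dict(state, strict=True)   # NamedTuple of missing/unexpected keys
>>> torch.equal(restored(input), qdq(input))                # same configuration, same fake cast
True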