FloatQuantizeDequantize

class aimet_torch.quantization.float.FloatQuantizeDequantize(exponent_bits=None, mantissa_bits=None, finite=None, unsigned_zero=None, dtype=None, encoding_analyzer=None)[source]

Simulates quantization by fake-casting the input

If dtype is provided, this is equivalent to

\[out = x.to(dtype).to(x.dtype)\]
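
For example, the dtype path can be checked against a manual round trip through the target dtype (a minimal sketch using the import convention from the Examples below; the match is just the documented equivalence restated):

>>> import torch
>>> import aimet_torch.v2.quantization as Q
>>> x = torch.randn(4, 4)
>>> qdq = Q.float.FloatQuantizeDequantize(dtype=torch.float16)
>>> torch.allclose(qdq(x), x.to(torch.float16).to(x.dtype))
True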

If the exponent and mantissa bits are provided, this is equivalent to

\[out = \left\lceil\frac{x_c}{scale}\right\rfloor * scale\]

where

\[\begin{split}
x_c   &= clamp(x, -max, max) \\
bias  &= 2^{exponent} - \log_2(max) + \log_2(2 - 2^{-mantissa}) - 1 \\
scale &= 2^{\left\lfloor \log_2 |x_c| + bias \right\rfloor - mantissa - bias}
\end{split}\]

The IEEE standard computes the maximum representable value by

\[max = (2 - 2^{-mantissa}) * 2^{\left\lfloor 0.5 * exponent\_max \right\rfloor}\]

where

\[exponent\_max = 2^{exponent} - 1\]
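
As a quick sanity check, these formulas reproduce the familiar IEEE maxima (a plain-Python sketch of the math above, not part of the AIMET API):

>>> def ieee_max(exponent_bits, mantissa_bits):
...     exponent_max = 2 ** exponent_bits - 1
...     return (2 - 2 ** -mantissa_bits) * 2 ** (exponent_max // 2)
>>> ieee_max(5, 10)  # float16
65504.0
>>> ieee_max(8, 7)   # bfloat16
3.3895313892515355e+38
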
Parameters:
  • exponent_bits (int) – Number of exponent bits to simulate

  • mantissa_bits (int) – Number of mantissa bits to simulate

  • dtype (torch.dtype) – torch.dtype to simulate. This argument is mutually exclusive with exponent_bits and mantissa_bits.

  • encoding_analyzer (EncodingAnalyzer) – If specified, the maximum value to represent will be determined dynamically based on the input statistics for finer precision.

Examples

>>> import torch
>>> import aimet_torch.v2.quantization as Q
>>> input = torch.tensor([[ 1.8998, -0.0947], [-1.0891, -0.1727]])
>>> qdq = Q.float.FloatQuantizeDequantize(mantissa_bits=7, exponent_bits=8)
>>> # Unlike AffineQuantizer, FloatQuantizer is initialized without calling compute_encodings()
>>> qdq.is_initialized()
True
>>> qdq.is_bfloat16()
True
>>> qdq.bitwidth
16
>>> qdq(input)
tensor([[ 1.8984, -0.0947],
        [-1.0859, -0.1729]])
>>> from aimet_torch.v2.quantization.encoding_analyzer import MinMaxEncodingAnalyzer
>>> encoding_analyzer = MinMaxEncodingAnalyzer(shape=[])
>>> qdq = Q.float.FloatQuantizeDequantize(dtype=torch.float16, encoding_analyzer=encoding_analyzer)
>>> qdq.is_float16()
True
>>> qdq.bitwidth
16
>>> qdq(input)
tensor([[ 1.8994, -0.0947],
        [-1.0889, -0.1727]])
property bitwidth

Returns bitwidth of the quantizer

compute_encodings()[source]

Observe inputs and update quantization parameters based on the input statistics. While compute_encodings is enabled, the quantizer's forward pass performs dynamic quantization using the batch statistics.
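
For example (a hedged sketch; assumes the MinMaxEncodingAnalyzer shown in the Examples above):

>>> import torch
>>> import aimet_torch.v2.quantization as Q
>>> from aimet_torch.v2.quantization.encoding_analyzer import MinMaxEncodingAnalyzer
>>> qdq = Q.float.FloatQuantizeDequantize(exponent_bits=4, mantissa_bits=3,
...                                       encoding_analyzer=MinMaxEncodingAnalyzer(shape=[]))
>>> with qdq.compute_encodings():
...     _ = qdq(torch.randn(16, 16))  # forward passes here observe input statistics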

property exponent_bits

Returns exponent bits

forward(input)[source]
Parameters:

input (Tensor) – Input to quantize and dequantize

Returns:

Quantize-dequantized output

classmethod from_encodings(encodings)[source]

Create quantizer object from encoding object

Return type:

FloatQuantizeDequantize

get_encodings()[source]

Return the quantizer’s encodings as an EncodingBase object

Return type:

Optional[FloatEncoding]
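
A round trip through get_encodings() and from_encodings() rebuilds an equivalent quantizer (a sketch; get_encodings() may return None if the quantizer is uninitialized):

>>> import aimet_torch.v2.quantization as Q
>>> qdq = Q.float.FloatQuantizeDequantize(exponent_bits=5, mantissa_bits=10)
>>> encoding = qdq.get_encodings()
>>> if encoding is not None:
...     clone = Q.float.FloatQuantizeDequantize.from_encodings(encoding)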

get_extra_state()[source]

Get extra state that describes which parameters are initialized.

is_bfloat16()[source]

Returns true if the current configuration simulates bfloat16

is_float16()[source]

Returns true if the current configuration simulates IEEE float16

load_state_dict(state_dict, strict=True)[source]

Copy parameters and buffers from state_dict into this module and its descendants.

If strict is True, then the keys of state_dict must exactly match the keys returned by this module’s state_dict() function.

Warning

If assign is True the optimizer must be created after the call to load_state_dict unless get_swap_module_params_on_conversion() is True.

Parameters:
  • state_dict (dict) – a dict containing parameters and persistent buffers.

  • strict (bool, optional) – whether to strictly enforce that the keys in state_dict match the keys returned by this module’s state_dict() function. Default: True

  • assign (bool, optional) – When set to False, the properties of the tensors in the current module are preserved whereas setting it to True preserves properties of the Tensors in the state dict. The only exception is the requires_grad field of Parameter for which the value from the module is preserved. Default: False

Returns:

  • missing_keys is a list of str containing any keys that are expected by this module but missing from the provided state_dict.

  • unexpected_keys is a list of str containing the keys that are not expected by this module but present in the provided state_dict.

Return type:

NamedTuple with missing_keys and unexpected_keys fields

Note

If a parameter or buffer is registered as None and its corresponding key exists in state_dict, load_state_dict() will raise a RuntimeError.
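
Because the quantizer is a torch.nn.Module, its state round-trips through the usual state_dict machinery (a sketch; with identically configured quantizers, both key lists are expected to be empty):

>>> import aimet_torch.v2.quantization as Q
>>> src = Q.float.FloatQuantizeDequantize(exponent_bits=8, mantissa_bits=7)
>>> dst = Q.float.FloatQuantizeDequantize(exponent_bits=8, mantissa_bits=7)
>>> result = dst.load_state_dict(src.state_dict(), strict=False)
>>> result.missing_keys, result.unexpected_keys
([], [])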

property mantissa_bits

Returns mantissa bits

set_extra_state(state)[source]

Set extra state that describes which parameters are initialized.