FloatQuantizeDequantize

class aimet_torch.quantization.float.FloatQuantizeDequantize(exponent_bits=None, mantissa_bits=None, finite=None, unsigned_zero=None, dtype=None, shape=None, block_size=None, encoding_analyzer=None)[source]

Simulates quantization by fake-casting the input

If dtype is provided, this is equivalent to

\[\begin{split}out = x.to(dtype).to(x.dtype) \\\end{split}\]
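The dtype-based mode amounts to a round trip through the lower-precision dtype. A minimal torch-only sketch (no AIMET required; bfloat16 is used purely as an illustration):

```python
import torch

# Fake-cast: convert to the low-precision dtype and straight back.
x = torch.tensor([1.8998, -0.0947, -1.0891, -0.1727])
out = x.to(torch.bfloat16).to(x.dtype)
```

The result keeps the original dtype, but only values representable in the target dtype survive the round trip.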

If the exponent and mantissa bits are provided, this is equivalent to

\[out = \left\lceil\frac{x_c}{scale}\right\rfloor * scale\]

where

\[\begin{split}x_c &= clamp(x, -max, max) \\ bias &= 2^{exponent} - \log_2(max) + \log_2(2 - 2^{-mantissa}) - 1 \\ scale &= 2 ^ {\left\lfloor \log_2 |x_c| + bias \right\rfloor - mantissa - bias} \\\end{split}\]

The IEEE standard computes the maximum representable value by

\[\begin{split}max = (2 - 2^{-mantissa}) * 2^{(\left\lfloor 0.5 * exponent\_max \right\rfloor)} \\\end{split}\]

where

\[\begin{split}exponent\_max = 2^{exponent} - 1 \\\end{split}\]
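As a worked check, the formulas above can be evaluated directly for IEEE float16 (exponent_bits=5, mantissa_bits=10). The sketch below uses plain torch and is independent of AIMET; all variable names are illustrative:

```python
import math
import torch

# Worked check of the formulas above for IEEE float16
# (exponent_bits=5, mantissa_bits=10).
exponent_bits, mantissa_bits = 5, 10

exponent_max = 2 ** exponent_bits - 1                           # 31
maxval = (2 - 2 ** -mantissa_bits) * 2.0 ** (exponent_max // 2)
# maxval == 65504.0, the largest finite float16 value

bias = (2 ** exponent_bits - math.log2(maxval)
        + math.log2(2 - 2 ** -mantissa_bits) - 1)

x = torch.tensor([1.8998, -0.0947, -1.0891, 100.0])
x_c = x.clamp(-maxval, maxval)
scale = 2.0 ** (torch.floor(torch.log2(x_c.abs()) + bias)
                - mantissa_bits - bias)
out = torch.round(x_c / scale) * scale
# For inputs in the normal (non-subnormal) float16 range, this matches
# a float16 round trip: x.to(torch.float16).to(torch.float32)
```

Note that the closed-form scale above does not model subnormal values; for inputs below the smallest normal magnitude the simulated scale diverges from a true float16 cast.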
Parameters:
  • exponent_bits (int) – Number of exponent bits to simulate. This argument is mutually exclusive with dtype.

  • mantissa_bits (int) – Number of mantissa bits to simulate. This argument is mutually exclusive with dtype.

  • finite (bool, optional) – If True, +/-inf is not representable. Defaults to False. Ignored when dtype is specified.

  • unsigned_zero (bool, optional) – If True, only unsigned zero (+0) is representable; if False, both +0 and -0 are representable. Defaults to False. Ignored when dtype is specified.

  • dtype (torch.dtype) – torch.dtype to simulate. This argument is mutually exclusive with exponent_bits and mantissa_bits.

  • shape (tuple, optional) – Shape of quantization scales. Defaults to () (= per-tensor quantization).

  • block_size (tuple, optional) – If specified, block-wise quantization is performed with the given block size.

  • encoding_analyzer (EncodingAnalyzer, optional) – If specified, quantization scale will be calibrated dynamically based on the input statistics. If not specified, sub-16-bit floating point quantizers will use min-max encoding analyzer for scale calibration; 16-bit or higher quantizers will be fixed at scale=1.0

Examples

>>> import aimet_torch.quantization as Q
>>> input = torch.tensor([[ 1.8998, -0.0947, -1.0891, -0.1727]])
>>> qdq = Q.float.FloatQuantizeDequantize(dtype=torch.float8_e4m3fnuz)
>>> with qdq.compute_encodings():
...     _ = qdq(input)
...
>>> qdq(input)
DequantizedTensor([[ 1.8998, -0.0950, -1.1399, -0.1741]])
property bitwidth

Returns the bitwidth of the quantizer (1 sign bit + exponent bits + mantissa bits)

compute_encodings()[source]

Observe inputs and update quantization parameters based on the input statistics. While compute_encodings is enabled, the quantizer's forward pass performs dynamic quantization using the batch statistics.

property exponent_bits

Returns exponent bits

forward(input)[source]
Parameters:

input (Tensor) – Input to quantize and dequantize

Returns:

Quantize-dequantized output

classmethod from_encodings(encodings)[source]

Create quantizer object from encoding object

Return type:

FloatQuantizeDequantize

get_encodings()[source]

Return the quantizer’s encodings as an EncodingBase object

Return type:

Optional[FloatEncoding]

get_extra_state()[source]

Get extra state that describes which parameters are initialized.

is_bfloat16()[source]

Returns True if the current configuration simulates bfloat16

is_float16()[source]

Returns True if the current configuration simulates IEEE float16

load_state_dict(state_dict, *args, **kwargs)[source]

Copy parameters and buffers from state_dict into this module and its descendants.

If strict is True, then the keys of state_dict must exactly match the keys returned by this module’s state_dict() function.

Warning

If assign is True the optimizer must be created after the call to load_state_dict unless get_swap_module_params_on_conversion() is True.

Parameters:
  • state_dict (dict) – a dict containing parameters and persistent buffers.

  • strict (bool, optional) – whether to strictly enforce that the keys in state_dict match the keys returned by this module’s state_dict() function. Default: True

  • assign (bool, optional) – When set to False, the properties of the tensors in the current module are preserved whereas setting it to True preserves properties of the Tensors in the state dict. The only exception is the requires_grad field of Parameter for which the value from the module is preserved. Default: False

Returns:

  • missing_keys is a list of str containing any keys that are expected by this module but missing from the provided state_dict.

  • unexpected_keys is a list of str containing the keys that are not expected by this module but present in the provided state_dict.

Return type:

NamedTuple with missing_keys and unexpected_keys fields

Note

If a parameter or buffer is registered as None and its corresponding key exists in state_dict, load_state_dict() will raise a RuntimeError.

property mantissa_bits

Returns mantissa bits

set_extra_state(state)[source]

Set extra state that describes which parameters are initialized.